├── Readme.md
├── algos
│   ├── __init__.py
│   ├── base.py
│   ├── buffer.py
│   ├── collect.py
│   └── ppo.py
├── commuication_net.py
├── env
│   ├── Game.py
│   ├── Game_RL.py
│   ├── __init__.py
│   ├── boxkey.py
│   ├── historicalobs.py
│   ├── keydistraction.py
│   ├── randomboxkey.py
│   └── singleroom.py
├── executive_net.py
├── img
│   ├── always.gif
│   ├── baseline.gif
│   ├── framework.png
│   ├── hard_code.gif
│   └── ours.gif
├── main.py
├── mediator.py
├── planner.py
├── prompt
│   └── task_info.json
├── requirements.txt
├── skill
│   ├── __init__.py
│   ├── base_skill.py
│   ├── explore.py
│   └── goto_goal.py
└── utils
    ├── __init__.py
    ├── env.py
    ├── eval.py
    ├── format.py
    ├── global_param.py
    ├── log.py
    └── logo.py

/Readme.md:
--------------------------------------------------------------------------------
1 | # [Enabling Intelligent Interactions between an Agent and an LLM: A Reinforcement Learning Approach](https://arxiv.org/pdf/2306.03604.pdf)
2 | ## Abstract
3 | Large language models (LLMs) encode a vast amount of world knowledge acquired from massive text datasets. Recent studies have demonstrated that LLMs can assist an algorithm agent in solving complex sequential decision-making tasks in embodied environments by providing high-level instructions. However, interacting with LLMs can be time-consuming: in many practical scenarios, they require a significant amount of storage space and can only be deployed on remote cloud server nodes. Additionally, using commercial LLMs can be costly, since they may charge based on usage frequency. In this paper, we explore how to enable efficient and cost-effective interactions between the agent and an LLM. We propose a reinforcement-learning-based mediator model that determines when it is necessary to consult the LLM for high-level instructions to accomplish a target task. Experiments on four MiniGrid environments that entail planning sub-goals demonstrate that our method can learn to solve target tasks with only a few necessary interactions with the LLM, significantly reducing interaction costs in testing environments compared with baseline methods. Experimental results also suggest that, by learning a mediator model to interact with the LLM, the agent's performance becomes more robust in both exploratory and stochastic environments.
4 | 
5 | ![llm4rl framework](img/framework.png)
6 | 
7 | ## Purpose
8 | This repo is intended to serve as a foundation with which you can reproduce the experimental results reported in our paper, [Enabling Intelligent Interactions between an Agent and an LLM: A Reinforcement Learning Approach](https://arxiv.org/pdf/2306.03604.pdf).
9 | 
10 | ## How to run the LLM
11 | For the LLM's deployment, refer to the instructions in its GitHub repository.
12 | 
13 | Here is an example of how to run Vicuna models in Linux terminals:
14 | 1. Download the necessary files and model weights from https://github.com/lm-sys/FastChat
15 | 2. Run the following commands to launch the API in separate terminals (e.g. tmux), filling in the bracketed placeholders:
16 |        python3 -m fastchat.serve.controller --host localhost --port <controller_port>   ### Launch the controller
17 |        python3 -m fastchat.serve.model_worker --model-name '<model_name>' --model-path <model_path> --controller http://localhost:<controller_port> --port <worker_port> --worker_address http://localhost:<worker_port>   ### Launch the model worker
18 |        python3 -m fastchat.serve.api --host <api_host> --port <api_port>   ### Launch the API
19 | 3. In planner.py, set url to 'http://API_host:API_port/v1/chat/completions' (a request sketch is included at the end of this README).
20 | 
21 | ## Running experiments
22 | 
23 | ### Basics
24 | Any algorithm can be run from the main.py entry point.
25 | 
26 | To train on the SimpleDoorKey environment:
27 | 
28 | ```bash
29 | python main.py train --task SimpleDoorKey --save_name experiment01
30 | ```
31 | 
32 | To evaluate the trained model "experiment01" on the SimpleDoorKey environment:
33 | 
34 | ```bash
35 | python main.py eval --task SimpleDoorKey --save_name experiment01 --show --record
36 | ```
37 | 
38 | To run the other baselines:
39 | 
40 | ```bash
41 | python main.py baseline --task SimpleDoorKey --save_name baseline
42 | python main.py random --task SimpleDoorKey --save_name random
43 | python main.py always --task SimpleDoorKey --save_name always
44 | ```
45 | 
46 | To train and evaluate the RL-only case (the RL policy selects skills directly, without the LLM):
47 | ```bash
48 | python main.py train_RL --task SimpleDoorKey --save_name RL
49 | python main.py eval_RL --task SimpleDoorKey --save_name RL
50 | ```
51 | 
52 | ## Logging details
53 | Tensorboard logging is enabled by default for all algorithms. The logger expects you to supply an argument named ```logdir```, containing the root directory in which you want to store your log files.
54 | 
55 | The resulting directory tree would look something like this:
56 | ```
57 | log/                                     # directory with all of the saved models and tensorboard logs
58 | └── ppo                                  # algorithm name
59 |     └── simpledoorkey                   # environment name
60 |         └── save_name                   # unique save name
61 |             ├── acmodel.pt              # actor and critic network for the algorithm
62 |             ├── events.out.tfevents     # tensorboard binary file
63 |             └── config.json             # readable hyperparameters for this run
64 | ```
65 | 
66 | Using tensorboard makes it easy to compare experiments and resume training later on.
67 | 
68 | To see live training progress:
69 | 
70 | Run ```$ tensorboard --logdir=log```, then navigate to ```http://localhost:6006/``` in your browser.
71 | 
72 | ## Environments:
73 | * `SimpleDoorKey` : The agent's task is to open the door in the maze with the key.
74 | * `KeyInBox` : The agent's task is to toggle the door in the maze. The key is hidden in a box.
75 | * `RandomBoxKey` : The agent's task is to toggle the door in the maze. The key is randomly placed either on the floor or in a box.
76 | * `ColoredDoorKey` : The agent's task is to toggle the door in the maze. The room contains multiple keys and only one exit door; the door can be unlocked only with the key of the matching color.
77 | 
78 | ## Algorithms:
79 | #### Currently implemented:
80 | * [PPO](https://arxiv.org/abs/1707.06347), plus VPG with a ratio objective and with a log-likelihood objective
81 | * [Vicuna-7B-v1.1](https://huggingface.co/lmsys/vicuna-7b-v1.1), the LLM used in our experiments
82 | 
83 | ## Demonstrations:
84 | #### Our approach:
85 | ![ours](img/ours.gif)
86 | 
87 | #### Hard-code baseline:
88 | ![hard_code](img/hard_code.gif)
89 | 
90 | #### Always baseline:
91 | ![always](img/always.gif)
92 | 
93 | ## Citation
94 | If you find [our work](https://arxiv.org/abs/2306.03604) useful, please kindly cite:
95 | ```bibtex
96 | @article{Hu2023enabling,
97 |   title = {Enabling Intelligent Interactions between an Agent and an LLM: A Reinforcement Learning Approach},
98 |   author = {Hu, Bin and Zhao, Chenyang and Zhang, Pu and Zhou, Zihao and Yang, Yuanhang and Xu, Zenglin and Liu, Bin},
99 |   journal = {arXiv preprint arXiv:2306.03604},
100 |   year = {2023}
101 | }
102 | ```
103 | 
104 | 
105 | ## Acknowledgements
106 | This work is supported by the Exploratory Research Project (No. 2022RC0AN02) of Zhejiang Lab.
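
For reference, planner.py queries the chat-completions endpoint configured in "How to run the LLM" above. Below is a minimal request sketch; the host, port, model name, and prompt are placeholders, and the payload fields are an assumption based on FastChat's OpenAI-compatible API rather than code taken from this repo.

```python
# Hypothetical request to the FastChat endpoint set in planner.py.
# Replace API_host / API_port and the model name with your own deployment values.
import requests

url = "http://API_host:API_port/v1/chat/completions"   # same URL pattern as step 3 above
payload = {
    "model": "vicuna-7b-v1.1",                          # the LLM used in our experiments
    "messages": [
        {"role": "user", "content": "You see a locked door and a key. What skill should the agent use next?"}
    ],
    "temperature": 0.0,
}
resp = requests.post(url, json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```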
107 | -------------------------------------------------------------------------------- /algos/__init__.py: -------------------------------------------------------------------------------- 1 | from .ppo import * 2 | from .buffer import * 3 | from .base import * -------------------------------------------------------------------------------- /algos/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : base.py 5 | @Time : 2023/05/18 09:42:14 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | 12 | from abc import ABC, abstractmethod 13 | 14 | class Base(ABC): 15 | """The base class for RL algorithms.""" 16 | 17 | def __init__(self, model, device, num_steps, lr, max_grad_norm, entropy_coef, value_loss_coef, num_worker): 18 | super().__init__() 19 | self.model = model 20 | self.device = device 21 | self.num_steps = num_steps 22 | self.lr = lr 23 | self.max_grad_norm = max_grad_norm 24 | self.entropy_coef = entropy_coef 25 | self.value_loss_coef = value_loss_coef 26 | self.num_worker = num_worker 27 | 28 | # self.model.to(self.device) 29 | # self.model.train() 30 | 31 | @abstractmethod 32 | def update_policy(self): 33 | pass -------------------------------------------------------------------------------- /algos/buffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : buffer.py 5 | @Time : 2023/04/19 09:40:36 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | import numpy as np 11 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 12 | from torch.nn.utils.rnn import pad_sequence 13 | import torch 14 | 15 | 16 | def Merge_Buffers(buffers, device='cpu'): 17 | merged = Buffer(device=device) 18 | for buf in buffers: 19 | offset = len(merged) 20 | 21 | merged.states += buf.states 22 | merged.actions += buf.actions 23 | merged.rewards += buf.rewards 24 | merged.values += buf.values 25 | merged.returns += buf.returns 26 | merged.log_probs += buf.log_probs 27 | 28 | 29 | merged.ep_returns += buf.ep_returns 30 | merged.ep_lens += buf.ep_lens 31 | merged.ep_game_returns += buf.ep_game_returns 32 | merged.ep_interactions += buf.ep_interactions 33 | 34 | 35 | merged.traj_idx += [offset + i for i in buf.traj_idx[1:]] 36 | merged.ptr += buf.ptr 37 | 38 | return merged 39 | 40 | 41 | class Buffer: 42 | """ 43 | A buffer for storing trajectory data and calculating returns for the policy 44 | and critic updates. 45 | """ 46 | def __init__(self, gamma=0.99, lam=0.95, device='cpu'): 47 | self.states = [] 48 | self.actions = [] 49 | self.rewards = [] 50 | self.values = [] 51 | self.returns = [] 52 | self.log_probs = [] 53 | self.game_rewards = [] # without communication penalty 54 | 55 | 56 | self.ep_returns = [] # for logging 57 | self.ep_lens = [] 58 | self.ep_game_returns = [] # without communication penalty 59 | self.ep_interactions = [] # interactions number 60 | 61 | 62 | self.gamma, self.lam = gamma, lam 63 | self.device = device 64 | 65 | self.ptr = 0 66 | self.traj_idx = [0] 67 | 68 | def __len__(self): 69 | return self.ptr 70 | 71 | def store(self, state, action, reward, value, log_probs, game_reward=None): 72 | """ 73 | Append one timestep of agent-environment interaction to the buffer. 
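        Shapes follow how Game.collect calls this method: `state`, `action`,
        `value` and `log_probs` arrive with a leading batch dimension of 1 and
        are squeezed before being stored; `reward` is the (communication-penalised)
        scalar reward for the step, and the optional `game_reward` keeps the raw
        environment reward separately for logging.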
74 | """ 75 | # TODO: make sure these dimensions really make sense 76 | self.states += [state.squeeze(0)] 77 | self.actions += [action.squeeze(0)] 78 | self.rewards += [reward.squeeze(0)] 79 | self.values += [value.squeeze(0)] 80 | self.log_probs += [log_probs.squeeze(0)] 81 | if game_reward is not None: 82 | self.game_rewards += [game_reward.squeeze(0)] 83 | self.ptr += 1 84 | 85 | def finish_path(self, last_val=None, interactions=None): 86 | self.traj_idx += [self.ptr] 87 | rewards = self.rewards[self.traj_idx[-2]:self.traj_idx[-1]] 88 | 89 | 90 | 91 | returns = [] 92 | 93 | R = last_val.squeeze(0).copy() # Avoid copy? 94 | for reward in reversed(rewards): 95 | R = self.gamma * R + reward 96 | returns.insert(0, R) 97 | 98 | self.returns += returns 99 | 100 | self.ep_returns += [np.sum(rewards)] 101 | if self.game_rewards != []: 102 | game_rewards = self.game_rewards[self.traj_idx[-2]:self.traj_idx[-1]] 103 | self.ep_game_returns += [np.sum(game_rewards)] 104 | 105 | self.ep_lens += [len(rewards)] 106 | self.ep_interactions += [interactions] 107 | 108 | 109 | 110 | def get(self): 111 | return( 112 | np.array(self.states), 113 | np.array(self.actions), 114 | np.array(self.returns), 115 | np.array(self.values), 116 | np.array(self.log_probs) 117 | ) 118 | 119 | def sample(self, batch_size=64, recurrent=False): 120 | if recurrent: 121 | random_indices = SubsetRandomSampler(range(len(self.traj_idx)-1)) 122 | sampler = BatchSampler(random_indices, batch_size, drop_last=False) 123 | else: 124 | random_indices = SubsetRandomSampler(range(self.ptr)) 125 | sampler = BatchSampler(random_indices, batch_size, drop_last=True) 126 | 127 | observations, actions, returns, values, log_probs = map(torch.Tensor, self.get()) 128 | 129 | advantages = returns - values 130 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5) 131 | 132 | for indices in sampler: 133 | if recurrent: 134 | obs_batch = [observations[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 135 | action_batch = [actions[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 136 | return_batch = [returns[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 137 | advantage_batch = [advantages[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 138 | values_batch = [values[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 139 | mask = [torch.ones_like(r) for r in return_batch] 140 | log_prob_batch = [log_probs[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 141 | 142 | obs_batch = pad_sequence(obs_batch, batch_first=False) 143 | action_batch = pad_sequence(action_batch, batch_first=False) 144 | return_batch = pad_sequence(return_batch, batch_first=False) 145 | advantage_batch = pad_sequence(advantage_batch, batch_first=False) 146 | values_batch = pad_sequence(values_batch, batch_first=False) 147 | mask = pad_sequence(mask, batch_first=False) 148 | log_prob_batch = pad_sequence(log_prob_batch, batch_first=False) 149 | else: 150 | obs_batch = observations[indices] 151 | action_batch = actions[indices] 152 | return_batch = returns[indices] 153 | advantage_batch = advantages[indices] 154 | values_batch = values[indices] 155 | mask = torch.FloatTensor([1]) 156 | log_prob_batch = log_probs[indices] 157 | 158 | 159 | yield obs_batch.to(self.device), action_batch.to(self.device), return_batch.to(self.device), advantage_batch.to(self.device), values_batch.to(self.device), mask.to(self.device), log_prob_batch.to(self.device) 160 | 161 | 162 | 
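
A minimal smoke test of the Buffer / Merge_Buffers flow above, filled with dummy data — the observation shape, rewards, and values below are illustrative assumptions, not values from the repo, and it assumes the script is run from the repository root so that `algos.buffer` is importable:

```python
# Illustrative only: fill a Buffer with fake transitions, close the trajectory,
# merge, and draw mini-batches the same way PPO.update_policy consumes them.
import numpy as np
from algos.buffer import Buffer, Merge_Buffers

buf = Buffer(gamma=0.99, lam=0.95)
for t in range(8):
    state = np.zeros((1, 10, 10, 4), dtype=np.float32)   # fake stacked observation
    action = np.array([t % 2])                           # fake binary "ask" decision
    reward = np.array([1.0 if t == 7 else 0.0])
    value = np.array([0.5])
    log_prob = np.array([-0.69])
    buf.store(state, action, reward, value, log_prob, game_reward=reward)
buf.finish_path(last_val=np.array([0.0]), interactions=3)  # bootstrap value 0 at episode end

merged = Merge_Buffers([buf], device="cpu")
for obs_b, act_b, ret_b, adv_b, val_b, mask, logp_b in merged.sample(batch_size=4, recurrent=False):
    print(obs_b.shape, ret_b.shape, adv_b.shape)           # e.g. (4, 10, 10, 4), (4,), (4,)
```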
-------------------------------------------------------------------------------- /algos/collect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : collect.py 5 | @Time : 2023/04/19 10:14:11 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | #import ray 12 | from copy import deepcopy 13 | import torch 14 | import time 15 | from algos.buffer import Buffer 16 | import numpy as np 17 | from utils.env import WrapEnv 18 | 19 | # @ray.remote(num_gpus=1) 20 | class Collect_Worker: 21 | def __init__(self, policy, critic, device, gamma=0.99, lam=0.95): 22 | self.gamma = gamma 23 | self.lam = lam 24 | self.policy = deepcopy(policy) 25 | self.critic = deepcopy(critic) 26 | self.device = device 27 | 28 | 29 | def sync_policy(self, new_actor_params, new_critic_params): 30 | for p, new_p in zip(self.policy.parameters(), new_actor_params): 31 | p.data.copy_(new_p) 32 | 33 | for p, new_p in zip(self.critic.parameters(), new_critic_params): 34 | p.data.copy_(new_p) 35 | 36 | 37 | def collect(self, max_traj_len, min_steps, env_fn, anneal=1.0): 38 | env = WrapEnv(env_fn) 39 | with torch.no_grad(): 40 | memory = Buffer(self.gamma, self.lam) 41 | num_steps = 0 42 | # state = env.reset() 43 | while num_steps < min_steps: 44 | state = torch.Tensor(env.reset()) 45 | done = False 46 | value = 0 47 | traj_len = 0 48 | 49 | if hasattr(self.policy, 'init_hidden_state'): 50 | self.policy.init_hidden_state() 51 | 52 | if hasattr(self.critic, 'init_hidden_state'): 53 | self.critic.init_hidden_state() 54 | 55 | while not done and traj_len < max_traj_len: 56 | action = self.policy(state.to(self.device), deterministic=False, anneal=anneal).to("cpu") 57 | value = self.critic(state.to(self.device)).to("cpu") 58 | 59 | next_state, reward, done, _ = env.step(action.numpy()) 60 | memory.store(state.numpy(), action.numpy(), reward, value.numpy()) 61 | 62 | state = torch.Tensor(next_state) 63 | traj_len += 1 64 | num_steps += 1 65 | 66 | value = self.critic(state.to(self.device)).to("cpu") 67 | memory.finish_path(last_val=(not done) * value.numpy()) 68 | 69 | return memory 70 | 71 | def evaluate(self, max_traj_len, env_fn, trajs=1): 72 | torch.set_num_threads(1) 73 | env = WrapEnv(env_fn) 74 | with torch.no_grad(): 75 | ep_returns = [] 76 | for traj in range(trajs): 77 | state = torch.Tensor(env.reset()) 78 | done = False 79 | traj_len = 0 80 | ep_return = 0 81 | 82 | if hasattr(self.policy, 'init_hidden_state'): 83 | self.policy.init_hidden_state() 84 | 85 | while not done and traj_len < max_traj_len: 86 | action = self.policy(state.to(self.device), deterministic=False, anneal=1.0).to("cpu") 87 | 88 | next_state, reward, done, _ = env.step(action.numpy()) 89 | 90 | state = torch.Tensor(next_state) 91 | ep_return += reward 92 | traj_len += 1 93 | ep_returns += [ep_return] 94 | return np.mean(ep_returns) 95 | 96 | -------------------------------------------------------------------------------- /algos/ppo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : ppo.py 5 | @Time : 2023/05/17 11:23:57 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | from .base import Base 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import os 16 | 17 | class PPO(Base): 18 | 19 | def __init__(self, model, device=None, save_path=None, num_steps=4096, 
lr=0.001, max_grad_norm=0.5, num_worker=4, epoch=3, batch_size=64, adam_eps=1e-8,entropy_coef=0.01,value_loss_coef=0.5,clip_eps=0.2,recurrent=False): 20 | super().__init__(model, device, num_steps, lr, max_grad_norm, entropy_coef, value_loss_coef, num_worker) 21 | 22 | self.lr = lr 23 | self.eps = adam_eps 24 | self.entropy_coeff = entropy_coef 25 | self.value_loss_coef= value_loss_coef 26 | self.clip = clip_eps 27 | self.minibatch_size = batch_size 28 | self.epochs = epoch 29 | self.num_steps = num_steps 30 | self.num_worker = num_worker 31 | self.grad_clip = max_grad_norm 32 | self.recurrent = recurrent 33 | self.device = device 34 | 35 | self.model = model.to(self.device) 36 | self.save_path = save_path 37 | 38 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr, eps=self.eps) 39 | 40 | def save(self): 41 | try: 42 | os.makedirs(self.save_path) 43 | except OSError: 44 | pass 45 | filetype = ".pt" # pytorch model 46 | torch.save(self.model, os.path.join(self.save_path, "acmodel" + filetype)) 47 | #torch.save(self.policy, os.path.join(self.save_path, "actor" + filetype)) 48 | #torch.save(self.critic, os.path.join(self.save_path, "critic" + filetype)) 49 | 50 | def update_policy(self, buffer): 51 | losses = [] 52 | for _ in range(self.epochs): 53 | for batch in buffer.sample(self.minibatch_size, self.recurrent): 54 | obs_batch, action_batch, return_batch, advantage_batch, values_batch, mask, log_prob_batch = batch 55 | pdf, value = self.model(obs_batch) 56 | 57 | entropy_loss = (pdf.entropy() * mask).mean() 58 | 59 | ratio = torch.exp(pdf.log_prob(action_batch) - log_prob_batch) 60 | surr1 = ratio * advantage_batch * mask 61 | surr2 = torch.clamp(ratio, 1.0 - self.clip, 1.0 + self.clip) * advantage_batch * mask 62 | policy_loss = -torch.min(surr1, surr2).mean() 63 | 64 | value_clipped = values_batch + torch.clamp(value - values_batch, -self.clip, self.clip) 65 | surr1 = ((value - return_batch)*mask).pow(2) 66 | surr2 = ((value_clipped - return_batch)*mask).pow(2) 67 | value_loss = torch.max(surr1, surr2).mean() 68 | 69 | loss = policy_loss - self.entropy_coef * entropy_loss + self.value_loss_coef * value_loss 70 | 71 | # Update actor-critic 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) 75 | self.optimizer.step() 76 | 77 | # Update log values 78 | losses.append([loss.item(), pdf.entropy().mean().item()]) 79 | mean_losses = np.mean(losses, axis=0) 80 | return mean_losses 81 | -------------------------------------------------------------------------------- /commuication_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : commuication_net.py 5 | @Time : 2023/05/16 16:34:11 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.distributions.categorical import Categorical 14 | import torch.nn.functional as F 15 | import numpy as np 16 | class Communication_Net(nn.Module): 17 | def __init__(self, obs_space, action_space): 18 | super().__init__() 19 | 20 | 21 | # Define image embedding 22 | self.image_conv = nn.Sequential( 23 | nn.Conv2d(obs_space['image'][-1], 16, (2, 2)), 24 | nn.ReLU(), 25 | nn.MaxPool2d((2, 2)), 26 | nn.Conv2d(16, 32, (2, 2)), 27 | nn.ReLU(), 28 | nn.Conv2d(32, 64, (2, 2)), 29 | nn.ReLU() 30 | ) 31 | n = obs_space["image"][0] 32 | m = obs_space["image"][1] 33 | self.embedding_size = 
((n-1)//2-2)*((m-1)//2-2)*64 34 | 35 | # Define actor's model 36 | self.actor = nn.Sequential( 37 | nn.Linear(self.embedding_size, 64), 38 | nn.Tanh(), 39 | nn.Linear(64, action_space) 40 | ) 41 | # Define critic's model 42 | self.critic = nn.Sequential( 43 | nn.Linear(self.embedding_size, 64), 44 | nn.Tanh(), 45 | nn.Linear(64, 1) 46 | ) 47 | 48 | def forward(self, obs): 49 | x = obs.transpose(1, 3).transpose(2, 3) 50 | 51 | x = self.image_conv(x) 52 | 53 | x = x.reshape(x.shape[0], -1) 54 | embedding = x 55 | 56 | x = self.actor(embedding) 57 | dist = Categorical(logits=F.log_softmax(x, dim=1)) 58 | x = self.critic(embedding) 59 | value = x.squeeze(1) 60 | 61 | return dist, value 62 | 63 | def get_action(self, obs, deterministic=True): 64 | dist, _ = self.forward(obs) 65 | if deterministic: 66 | action = torch.argmax(dist.probs) 67 | else: 68 | action = dist.sample() 69 | return action 70 | 71 | class Multi_head_Communication_Net(Communication_Net): 72 | def __init__(self, obs_space, action_space, heads): 73 | super().__init__(obs_space, action_space) 74 | 75 | self.heads = heads 76 | self.hidden = nn.Sequential( 77 | nn.Linear(self.embedding_size, 64), 78 | nn.Tanh(), 79 | ) 80 | # Define multi-heads actor's model 81 | self.multi_heads_actor = [] 82 | for l in range(self.heads): 83 | actor_head = nn.Linear(64, action_space) 84 | self.multi_heads_actor.append(actor_head) 85 | 86 | def forward(self, obs_skill, deterministic=True): 87 | skill = [] 88 | obs = [] 89 | for i, os in enumerate(obs_skill): 90 | skill.append(os[-1]) 91 | obs.append(os[0]) 92 | obs = torch.stack(obs) 93 | skill = np.array(skill) 94 | x = obs.transpose(1, 3).transpose(2, 3) 95 | 96 | x = self.image_conv(x) 97 | 98 | x = x.reshape(x.shape[0], -1) 99 | 100 | embedding = x 101 | hidden = self.hidden(embedding) 102 | 103 | output = [] 104 | for i, actor in enumerate(self.multi_heads_actor): 105 | x = actor.to(obs.device)(hidden) 106 | #dist = Categorical(logits=F.log_softmax(x, dim=1)) 107 | output.append(x) 108 | 109 | x = self.critic(embedding) 110 | value = x.squeeze(1) 111 | 112 | dist_x = torch.zeros(len(skill), 2).to(embedding.device) 113 | for i in range(len(skill)): 114 | if skill[i] is None: 115 | dist_choose = np.random.choice(self.heads) 116 | elif skill[i][0]['action'] == 0: 117 | dist_choose = 0 118 | elif skill[i][0]['action'] == 1: 119 | dist_choose = 1 120 | elif skill[i][0]['action'] == 2: 121 | dist_choose = 2 122 | elif skill[i][0]['action'] == 4: 123 | dist_choose = 3 124 | dist_x[i] = output[dist_choose][0] 125 | dist = Categorical(logits=F.log_softmax(dist_x, dim=1)) 126 | return dist, value 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | pass 132 | -------------------------------------------------------------------------------- /env/Game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : Game.py 5 | @Time : 2023/05/25 11:06:59 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | from planner import Planner 11 | from commuication_net import Communication_Net 12 | from executive_net import Executive_net 13 | import os, json, sys 14 | import utils 15 | import gymnasium as gym 16 | import env 17 | import algos 18 | import torch 19 | import numpy as np 20 | import time 21 | import skill 22 | import cv2 23 | 24 | prefix = os.getcwd() 25 | task_info_json = os.path.join(prefix, "prompt/task_info.json") 26 | class Game: 27 | def __init__(self, args, policy=None, 
run_seed=0): 28 | 29 | self.task = args.task 30 | self.device = args.device 31 | self.planner = Planner(self.task, run_seed) 32 | self.max_ep_len, self.decription, self.task_level, self.task_example, self.configurations = self.load_task_info(self.task) 33 | if args.seed is not None: 34 | self.seed = int(args.seed) 35 | else: 36 | self.seed = args.seed 37 | self.ask_lambda = args.ask_lambda 38 | self.gamma = args.gamma 39 | self.lam = args.lam 40 | self.batch_size = args.batch_size 41 | self.frame_stack = args.frame_stack 42 | 43 | self.env_fn = utils.make_env_fn(self.configurations, render_mode="rgb_array", frame_stack = args.frame_stack) 44 | self.agent_view_size = self.env_fn().agent_view_size 45 | 46 | obs_space, preprocess_obss = utils.get_obss_preprocessor(self.env_fn().observation_space) 47 | self.record_frames = args.record 48 | if policy is None: 49 | ## communication network for ask 50 | self.Communication_net = Communication_Net(obs_space, 2).to(self.device) 51 | else: 52 | self.Communication_net = policy.to(self.device) 53 | self.buffer = algos.Buffer(self.gamma, self.lam, self.device) 54 | 55 | self.logger = utils.create_logger(args) 56 | 57 | self.ppo_algo = algos.PPO(self.Communication_net, device=self.device, save_path=self.logger.dir, batch_size=self.batch_size) 58 | 59 | self.n_itr = args.n_itr 60 | self.total_steps = 0 61 | self.show_dialogue = args.show 62 | 63 | ## global_param for explore skill 64 | utils.global_param.init() 65 | self.traj_per_itr = args.traj_per_itr 66 | 67 | def load_task_info(self, task): 68 | with open(task_info_json, 'r') as f: 69 | task_info = json.load(f) 70 | task = task.lower() 71 | episode_length = int(task_info[task]["episode"]) 72 | task_description = task_info[task]['description'] 73 | task_example = task_info[task]['example'] 74 | task_level = task_info[task]['level'] 75 | task_configurations = task_info[task]['configurations'] 76 | return episode_length, task_description, task_level, task_example, task_configurations 77 | 78 | 79 | def reset(self): 80 | print(f"[INFO]: resetting the task: {self.task}") 81 | self.planner.initial_planning(self.decription, self.task_example) 82 | 83 | def train(self): 84 | start_time = time.time() 85 | for itr in range(self.n_itr): 86 | print("********** Iteration {} ************".format(itr)) 87 | print("time elapsed: {:.2f} s".format(time.time() - start_time)) 88 | 89 | ## collecting ## 90 | sample_start = time.time() 91 | buffer = [] 92 | for _ in range(self.traj_per_itr): 93 | buffer.append(self.collect(self.env_fn,seed=self.seed)) 94 | self.buffer = algos.Merge_Buffers(buffer,device=self.device) 95 | total_steps = len(self.buffer) 96 | samp_time = time.time() - sample_start 97 | print("{:.2f} s to collect {:6n} timesteps | {:3.2f}sample/s.".format(samp_time, total_steps, (total_steps)/samp_time)) 98 | self.total_steps += total_steps 99 | 100 | ## training ## 101 | optimizer_start = time.time() 102 | mean_losses = self.ppo_algo.update_policy(self.buffer) 103 | opt_time = time.time() - optimizer_start 104 | print("{:.2f} s to optimizer| loss {:6.3f}, entropy {:6.3f}.".format(opt_time, mean_losses[0], mean_losses[1])) 105 | 106 | ## eval_policy ## 107 | evaluate_start = time.time() 108 | eval_reward, eval_len, eval_interactions, eval_game_reward = self.eval(self.env_fn, trajs=1, seed=self.seed, show_dialogue=self.show_dialogue) 109 | eval_time = time.time() - evaluate_start 110 | print("{:.2f} s to evaluate.".format(eval_time)) 111 | if self.logger is not None: 112 | avg_eval_reward = eval_reward 113 | 
avg_batch_reward = np.mean(self.buffer.ep_returns) 114 | std_batch_reward = np.std(self.buffer.ep_returns) 115 | avg_batch_game_reward = np.mean(self.buffer.ep_game_returns) 116 | std_batch_game_reward = np.std(self.buffer.ep_game_returns) 117 | avg_ep_len = np.mean(self.buffer.ep_lens) 118 | success_rate = sum(i<100 for i in self.buffer.ep_lens) / 10. 119 | avg_eval_len = eval_len 120 | sys.stdout.write("-" * 37 + "\n") 121 | sys.stdout.write("| %15s | %15s |" % ('Timesteps', self.total_steps) + "\n") 122 | sys.stdout.write("| %15s | %15s |" % ('Return (test)', round(avg_eval_reward,2)) + "\n") 123 | sys.stdout.write("| %15s | %15s |" % ('Ep Lens (test) ', round(avg_eval_len,2)) + "\n") 124 | sys.stdout.write("| %15s | %15s |" % ('Ep Comm (test) ', round(eval_interactions,2)) + "\n") 125 | sys.stdout.write("| %15s | %15s |" % ('Return (batch)', round(avg_batch_reward,2)) + "\n") 126 | sys.stdout.write("| %15s | %15s |" % ('Mean Eplen', round(avg_ep_len,2)) + "\n") 127 | sys.stdout.write("-" * 37 + "\n") 128 | sys.stdout.flush() 129 | 130 | self.logger.add_scalar("Test/Return", avg_eval_reward, itr) 131 | self.logger.add_scalar("Test/Game Return", eval_game_reward, itr) 132 | self.logger.add_scalar("Test/Mean Eplen", avg_eval_len, itr) 133 | self.logger.add_scalar("Test/Comm", eval_interactions, itr) 134 | self.logger.add_scalar("Train/Return Mean", avg_batch_reward, itr) 135 | self.logger.add_scalar("Train/Return Std", std_batch_reward, itr) 136 | self.logger.add_scalar("Train/Game Return Mean", avg_batch_game_reward, itr) 137 | self.logger.add_scalar("Train/Game Return Std", std_batch_game_reward, itr) 138 | self.logger.add_scalar("Train/Eplen", avg_ep_len, itr) 139 | self.logger.add_scalar("Train/Success Rate", success_rate, itr) 140 | self.logger.add_scalar("Train/Loss", mean_losses[0], itr) 141 | self.logger.add_scalar("Train/Mean Entropy", mean_losses[1], itr) 142 | 143 | self.ppo_algo.save() 144 | 145 | 146 | def collect(self, env_fn, seed=None): 147 | # Do one agent-environment interaction 148 | env = utils.WrapEnv(env_fn) 149 | with torch.no_grad(): 150 | buffer = algos.Buffer(self.gamma, self.lam) 151 | obs = env.reset(seed) 152 | self.planner.reset() 153 | 154 | done = False 155 | skill_done = True 156 | traj_len = 0 157 | interactions = 0 158 | pre_skill = None 159 | if self.frame_stack >1: 160 | com_obs = obs 161 | else: 162 | his_obs = obs 163 | com_obs = obs - his_obs 164 | 165 | utils.global_param.set_value('exp', None) 166 | utils.global_param.set_value('explore_done', False) 167 | while not done and traj_len < self.max_ep_len: 168 | dist, value = self.Communication_net(torch.Tensor(com_obs).to(self.device)) 169 | ask_flag = dist.sample() 170 | log_probs = dist.log_prob(ask_flag) 171 | 172 | if skill_done or ask_flag: 173 | interactions += 1 174 | skill = self.planner(obs) 175 | # print(skill) 176 | if pre_skill == skill: # additional penalty term for repeat same skill 177 | repeat_feedback = np.array([.5]) 178 | elif pre_skill is None: 179 | repeat_feedback = np.array([0.]) 180 | else: 181 | repeat_feedback = np.array([-0.1]) 182 | pre_skill = skill 183 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 184 | 185 | ## RL choose skill, and return action_list 186 | action, skill_done = self.Executive_net(obs[0]) 187 | 188 | ## one step do one action in action_list 189 | next_obs, reward, done, info = env.step(np.array([action])) 190 | 191 | comm_penalty = (self.ask_lambda + 0.1 * repeat_feedback) * (ask_flag.to("cpu").numpy() or skill_done) ## 
communication penalty 192 | comm_reward = reward - comm_penalty 193 | 194 | buffer.store(com_obs, ask_flag.to("cpu").numpy(), comm_reward, value.to("cpu").numpy(), log_probs.to("cpu").numpy(), reward) 195 | if self.frame_stack >1: 196 | obs = next_obs 197 | com_obs = obs 198 | else: 199 | his_obs = obs 200 | obs = next_obs 201 | com_obs = obs - his_obs 202 | traj_len += 1 203 | _, value = self.Communication_net(torch.Tensor(com_obs).to(self.device)) 204 | buffer.finish_path(last_val=(not done) * value.to("cpu").numpy(), interactions=interactions) 205 | 206 | return buffer 207 | 208 | 209 | 210 | 211 | def eval(self, env_fn, trajs=1, seed=None, show_dialogue=False): 212 | env = utils.WrapEnv(env_fn) 213 | with torch.no_grad(): 214 | ep_returns = [] 215 | ep_game_returns = [] 216 | ep_lens = [] 217 | ep_interactions = [] 218 | for traj in range(trajs): 219 | obs = env.reset(seed) 220 | #print(f"[INFO]: Evaluating the task is ", self.task) 221 | self.planner.reset(show_dialogue) 222 | 223 | if self.record_frames: 224 | img_array = [] 225 | dir_path = os.path.join(self.logger.dir, 'video') 226 | try: 227 | os.makedirs(dir_path) 228 | except OSError: 229 | pass 230 | txt_path = os.path.join(dir_path, str(seed) + '.txt') 231 | with open(txt_path, 'a+') as f: 232 | f.seek(0) 233 | f.truncate() 234 | 235 | done = False 236 | skill_done = True 237 | traj_len = 0 238 | pre_skill = None 239 | ep_return = 0 240 | ep_game_return = 0 241 | interactions = 0 242 | if self.frame_stack > 1: 243 | com_obs = obs 244 | else: 245 | his_obs = obs 246 | com_obs = obs - his_obs 247 | utils.global_param.set_value('exp', None) 248 | utils.global_param.set_value('explore_done', False) 249 | while not done and traj_len < self.max_ep_len: 250 | ask_flag = self.Communication_net.get_action(torch.Tensor(com_obs).to(self.device)) 251 | 252 | if skill_done or ask_flag: 253 | interactions += 1 254 | skill = self.planner(obs) 255 | # print(skill) 256 | if pre_skill == skill: # additional penalty term for repeat same skill 257 | repeat_feedback = np.array([.5]) 258 | elif pre_skill is None: 259 | repeat_feedback = np.array([0.]) 260 | else: 261 | repeat_feedback = np.array([-0.1]) 262 | pre_skill = skill 263 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 264 | 265 | if self.record_frames: 266 | img = env.get_mask_render() 267 | text = str(traj_len) + ' ' + self.Executive_net.current_skill 268 | if skill_done or ask_flag == 1: 269 | text += ' (ask)' 270 | with open(txt_path, 'a+') as f: 271 | f.write('step:' + str(traj_len) + '\n' + self.planner.dialogue_user + '\n') 272 | 273 | cv2.putText(img, 274 | text, 275 | org=(10,300), 276 | fontFace=cv2.FONT_HERSHEY_TRIPLEX, 277 | fontScale=0.6, 278 | color=(255,255,255), 279 | thickness=1, 280 | lineType=cv2.LINE_AA) 281 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 282 | img_array.append(img) 283 | 284 | ## RL choose skill, and return action_list 285 | action, skill_done = self.Executive_net(obs[0]) 286 | ## one step do one action in action_list 287 | next_obs, reward, done, info = env.step(np.array([action])) 288 | 289 | if self.frame_stack >1: 290 | obs = next_obs 291 | com_obs = obs 292 | else: 293 | his_obs = obs 294 | obs = next_obs 295 | com_obs = obs - his_obs 296 | comm_penalty = (self.ask_lambda + 0.1 * repeat_feedback) * (ask_flag.to("cpu").numpy() or skill_done) ## communication penalty 297 | comm_reward = reward - comm_penalty 298 | #reward = 1.0*reward 299 | ep_return += comm_reward 300 | ep_game_return += 1.0*reward 301 | traj_len += 1 302 | if 
show_dialogue: 303 | if done: 304 | print("RL: Task Completed! \n") 305 | else: 306 | print("RL: Task Fail! \n") 307 | ep_returns += [ep_return] 308 | ep_game_returns += [ep_game_return] 309 | ep_lens += [traj_len] 310 | ep_interactions += [interactions] 311 | 312 | if self.record_frames: 313 | # save vedio 314 | height, width, layers = img.shape 315 | size = (width,height) 316 | video_path = os.path.join(dir_path, str(seed) + '.avi') 317 | out = cv2.VideoWriter(video_path, 318 | fourcc=cv2.VideoWriter_fourcc(*'DIVX'), 319 | fps=3, 320 | frameSize=size) 321 | 322 | for img in img_array: 323 | out.write(img) 324 | out.release() 325 | return np.mean(ep_returns), np.mean(ep_lens), np.mean(ep_interactions), np.mean(ep_game_returns) 326 | 327 | def baseline_eval(self, env_fn, trajs=1, seed=None, show_dialogue=False): 328 | env = utils.WrapEnv(env_fn) 329 | with torch.no_grad(): 330 | ep_returns = [] 331 | ep_game_returns = [] 332 | ep_lens = [] 333 | ep_interactions = [] 334 | for traj in range(trajs): 335 | obs = env.reset(seed) 336 | # print(f"[INFO]: Evaluating the task is ", self.task) 337 | self.planner.reset(show_dialogue) 338 | 339 | if self.record_frames: 340 | img_array = [] 341 | dir_path = os.path.join(self.logger.dir, 'video') 342 | try: 343 | os.makedirs(dir_path) 344 | except OSError: 345 | pass 346 | txt_path = os.path.join(dir_path, str(seed) + '.txt') 347 | with open(txt_path, 'a+') as f: 348 | f.seek(0) 349 | f.truncate() 350 | 351 | done = False 352 | skill_done = True 353 | traj_len = 0 354 | ep_return = 0 355 | ep_game_return = 0 356 | interactions = 0 357 | utils.global_param.set_value('exp', None) 358 | utils.global_param.set_value('explore_done', False) 359 | while not done and traj_len < self.max_ep_len: 360 | if skill_done: 361 | skill = self.planner(obs) 362 | # print(skill) 363 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 364 | interactions += 1 365 | 366 | if self.record_frames: 367 | img = env.get_mask_render() 368 | text = str(traj_len) + ' ' + self.Executive_net.current_skill 369 | if skill_done: 370 | text += ' (ask)' 371 | with open(txt_path, 'a+') as f: 372 | f.write('step:' + str(traj_len) + '\n' + self.planner.logging_dialogue + '\n') 373 | 374 | cv2.putText(img, 375 | text, 376 | org=(10,300), 377 | fontFace=cv2.FONT_HERSHEY_TRIPLEX, 378 | fontScale=0.6, 379 | color=(255,255,255), 380 | thickness=1, 381 | lineType=cv2.LINE_AA) 382 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 383 | img_array.append(img) 384 | 385 | ## RL choose skill, and return action_list 386 | action, skill_done = self.Executive_net(obs[0]) 387 | ## one step do one action in action_list 388 | obs, reward, done, info = env.step(np.array([action])) 389 | comm_reward = reward - self.ask_lambda * float(skill_done) ## communication penalty 390 | 391 | ep_return += comm_reward 392 | ep_game_return += 1.0*reward 393 | traj_len += 1 394 | if show_dialogue: 395 | if done: 396 | print("RL: Task Completed! \n") 397 | else: 398 | print("RL: Task Fail! 
\n") 399 | ep_returns += [ep_return] 400 | ep_game_returns += [ep_game_return] 401 | ep_lens += [traj_len] 402 | ep_interactions += [interactions] 403 | 404 | if self.record_frames: 405 | # save vedio 406 | height, width, layers = img.shape 407 | size = (width,height) 408 | video_path = os.path.join(dir_path, str(seed) + '.avi') 409 | out = cv2.VideoWriter(video_path, 410 | fourcc=cv2.VideoWriter_fourcc(*'DIVX'), 411 | fps=3, 412 | frameSize=size) 413 | 414 | for img in img_array: 415 | out.write(img) 416 | out.release() 417 | return np.mean(ep_returns), np.mean(ep_lens), np.mean(ep_interactions), np.mean(ep_game_returns) 418 | 419 | def ask_eval(self, env_fn, trajs=1, seed=None, show_dialogue=False): 420 | env = utils.WrapEnv(env_fn) 421 | with torch.no_grad(): 422 | ep_returns = [] 423 | ep_game_returns = [] 424 | ep_lens = [] 425 | ep_interactions = [] 426 | for traj in range(trajs): 427 | obs = env.reset(seed) 428 | #print(f"[INFO]: Evaluating the task is ", self.task) 429 | self.planner.reset(show_dialogue) 430 | if self.record_frames: 431 | img_array = [] 432 | dir_path = os.path.join(self.logger.dir, 'video') 433 | try: 434 | os.makedirs(dir_path) 435 | except OSError: 436 | pass 437 | txt_path = os.path.join(dir_path, str(seed) + '.txt') 438 | with open(txt_path, 'a+') as f: 439 | f.seek(0) 440 | f.truncate() 441 | 442 | done = False 443 | skill_done = True 444 | traj_len = 0 445 | pre_skill = None 446 | ep_return = 0 447 | ep_game_return = 0 448 | utils.global_param.set_value('exp', None) 449 | utils.global_param.set_value('explore_done', False) 450 | while not done and traj_len < self.max_ep_len: 451 | ## always ask 452 | ask_flag = torch.Tensor([1]) 453 | if skill_done or ask_flag: 454 | skill = self.planner(obs) 455 | # print(skill) 456 | if pre_skill == skill: # additional penalty term for repeat same skill 457 | repeat_feedback = np.array([.5]) 458 | elif pre_skill is None: 459 | repeat_feedback = np.array([0.]) 460 | else: 461 | repeat_feedback = np.array([-0.1]) 462 | pre_skill = skill 463 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 464 | 465 | if self.record_frames: 466 | img = env.get_mask_render() 467 | text = str(traj_len) + ' ' + self.Executive_net.current_skill 468 | text += ' (ask)' 469 | with open(txt_path, 'a+') as f: 470 | f.write('step:' + str(traj_len) + '\n' + self.planner.logging_dialogue + '\n') 471 | 472 | cv2.putText(img, 473 | text, 474 | org=(10,300), 475 | fontFace=cv2.FONT_HERSHEY_TRIPLEX, 476 | fontScale=0.6, 477 | color=(255,255,255), 478 | thickness=1, 479 | lineType=cv2.LINE_AA) 480 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 481 | img_array.append(img) 482 | 483 | ## RL choose skill, and return action_list 484 | action, skill_done = self.Executive_net(obs[0]) 485 | ## one step do one action in action_list 486 | obs, reward, done, info = env.step(np.array([action])) 487 | comm_penalty = (self.ask_lambda + 0.1 * repeat_feedback) * ask_flag.to("cpu").numpy() ## communication penalty 488 | comm_reward = reward - comm_penalty 489 | #reward = 1.0*reward 490 | ep_return += comm_reward 491 | ep_game_return += 1.0*reward 492 | traj_len += 1 493 | if show_dialogue: 494 | if done: 495 | print("RL: Task Completed! \n") 496 | else: 497 | print("RL: Task Fail! 
\n") 498 | ep_returns += [ep_return] 499 | ep_game_returns += [ep_game_return] 500 | ep_lens += [traj_len] 501 | ep_interactions += [traj_len] 502 | 503 | if self.record_frames: 504 | # save vedio 505 | height, width, layers = img.shape 506 | size = (width,height) 507 | video_path = os.path.join(dir_path, str(seed) + '.avi') 508 | out = cv2.VideoWriter(video_path, 509 | fourcc=cv2.VideoWriter_fourcc(*'DIVX'), 510 | fps=3, 511 | frameSize=size) 512 | 513 | for img in img_array: 514 | out.write(img) 515 | out.release() 516 | return np.mean(ep_returns), np.mean(ep_lens), np.mean(ep_interactions), np.mean(ep_game_returns) 517 | if __name__ == '__main__': 518 | pass 519 | 520 | -------------------------------------------------------------------------------- /env/Game_RL.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : Game_RL.py 5 | @Time : 2023/06/01 09:17:57 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | from commuication_net import Communication_Net 12 | from executive_net import Executive_net 13 | import os, json, sys 14 | import utils 15 | import gymnasium as gym 16 | import env 17 | import algos 18 | import torch 19 | import numpy as np 20 | import time 21 | import skill 22 | from mediator import SKILL_TO_IDX, SimpleDoorKey_Mediator 23 | from .Game import Game 24 | prefix = os.getcwd() 25 | task_info_json = os.path.join(prefix, "prompt/task_info.json") 26 | class Game_RL(Game): 27 | def __init__(self, args, policy=None): 28 | super().__init__(args, policy) 29 | obs_space, preprocess_obss = utils.get_obss_preprocessor(self.env_fn().observation_space) 30 | self.record_frames = args.record 31 | if policy is None: 32 | ## communication network for ask 33 | """ select skill by RL instand of LLM 34 | Action 0: Explore 35 | Action 1: go to key 36 | Action 2: go to door 37 | Action 3: pickup key 38 | Action 4: toggle door 39 | """ 40 | self.RL_net = Communication_Net(obs_space, 5).to(self.device) 41 | else: 42 | self.RL_net = policy.to(self.device) 43 | 44 | self.ppo_algo = algos.PPO(self.RL_net, device=self.device, save_path=self.logger.dir, batch_size=self.batch_size) 45 | 46 | self.mediator = SimpleDoorKey_Mediator() 47 | 48 | 49 | def flag2skill(self,obs, skill_flag): 50 | 51 | # print(text) 52 | goal = {} 53 | if skill_flag == 0: 54 | goal["action"] = SKILL_TO_IDX["explore"] 55 | goal["object"] = None 56 | elif skill_flag == 1: 57 | agent_map = obs[:, :, 3] 58 | agent_pos = np.argwhere(agent_map != 4)[0] 59 | if 'key' not in self.mediator.obj_coordinate.keys() or obs[:,:,0][agent_pos[0],agent_pos[1]] == 5: 60 | goal["action"] = None 61 | else: 62 | goal["action"] = SKILL_TO_IDX["go to object"] 63 | goal["coordinate"] = self.mediator.obj_coordinate["key"] 64 | elif skill_flag == 2: 65 | if 'door' not in self.mediator.obj_coordinate.keys(): 66 | goal["action"] = None 67 | else: 68 | goal["action"] = SKILL_TO_IDX["go to object"] 69 | goal["coordinate"] = self.mediator.obj_coordinate["door"] 70 | elif skill_flag == 3: 71 | goal["action"] = SKILL_TO_IDX["pickup"] 72 | elif skill_flag == 4: 73 | goal["action"] = SKILL_TO_IDX["toggle"] 74 | return [goal] 75 | 76 | def collect(self, env_fn, seed=None): 77 | # Do one agent-environment interaction 78 | env = utils.WrapEnv(env_fn) 79 | with torch.no_grad(): 80 | buffer = algos.Buffer(self.gamma, self.lam) 81 | obs = env.reset(seed) 82 | self.mediator.reset() 83 | done = False 84 | 85 | traj_len = 0 86 | pre_skill = 
None 87 | if self.frame_stack >1: 88 | com_obs = obs 89 | else: 90 | his_obs = obs 91 | com_obs = obs - his_obs 92 | 93 | utils.global_param.set_value('exp', None) 94 | utils.global_param.set_value('explore_done', False) 95 | while not done and traj_len < self.max_ep_len: 96 | dist, value = self.RL_net(torch.Tensor(com_obs).to(self.device)) 97 | skill_flag = dist.sample() 98 | log_probs = dist.log_prob(skill_flag) 99 | skill = self.flag2skill(obs[0],skill_flag) 100 | if skill != pre_skill or skill_done: 101 | # print(skill) 102 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 103 | pre_skill = skill 104 | 105 | ## RL choose skill, and return action_list 106 | action, skill_done = self.Executive_net(obs[0]) 107 | 108 | ## one step do one action in action_list 109 | next_obs, reward, done, info = env.step(np.array([action])) 110 | 111 | buffer.store(com_obs, skill_flag.to("cpu").numpy(), reward, value.to("cpu").numpy(), log_probs.to("cpu").numpy()) 112 | if self.frame_stack >1: 113 | obs = next_obs 114 | com_obs = obs 115 | else: 116 | his_obs = obs 117 | obs = next_obs 118 | com_obs = obs - his_obs 119 | traj_len += 1 120 | _, value = self.RL_net(torch.Tensor(com_obs).to(self.device)) 121 | buffer.finish_path(last_val=(not done) * value.to("cpu").numpy()) 122 | return buffer 123 | 124 | 125 | 126 | 127 | def eval(self, env_fn, trajs=1, seed=None, show_dialogue=False): 128 | env = utils.WrapEnv(env_fn) 129 | with torch.no_grad(): 130 | ep_returns = [] 131 | ep_lens = [] 132 | for traj in range(trajs): 133 | obs = env.reset(seed) 134 | self.mediator.reset() 135 | if self.record_frames: 136 | video_frames = [obs['rgb']] 137 | goal_frames = ["start"] 138 | 139 | done = False 140 | skill_done = True 141 | traj_len = 0 142 | ep_return = 0 143 | pre_skill = None 144 | if self.frame_stack > 1: 145 | com_obs = obs 146 | else: 147 | his_obs = obs 148 | com_obs = obs - his_obs 149 | utils.global_param.set_value('exp', None) 150 | utils.global_param.set_value('explore_done', False) 151 | while not done and traj_len < self.max_ep_len: 152 | skill_flag = self.RL_net.get_action(torch.Tensor(com_obs).to(self.device)) 153 | 154 | skill = self.flag2skill(obs[0],skill_flag) 155 | if skill != pre_skill or skill_done: 156 | self.Executive_net = Executive_net(skill,obs[0],self.agent_view_size) 157 | pre_skill = skill 158 | 159 | ## RL choose skill, and return action_list 160 | action, skill_done = self.Executive_net(obs[0]) 161 | ## one step do one action in action_list 162 | next_obs, reward, done, info = env.step(np.array([action])) 163 | if self.frame_stack >1: 164 | obs = next_obs 165 | com_obs = obs 166 | else: 167 | his_obs = obs 168 | obs = next_obs 169 | com_obs = obs - his_obs 170 | 171 | ep_return += 1.0*reward 172 | traj_len += 1 173 | if show_dialogue: 174 | if done: 175 | print("RL: Task Completed! \n") 176 | else: 177 | print("RL: Task Fail! 
\n") 178 | ep_returns += [ep_return] 179 | ep_lens += [traj_len] 180 | 181 | return np.mean(ep_returns), np.mean(ep_lens), 0., np.mean(ep_returns) 182 | 183 | 184 | if __name__ == '__main__': 185 | pass 186 | 187 | -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- 1 | from .historicalobs import * 2 | from .singleroom import * 3 | from .boxkey import * 4 | from .randomboxkey import * 5 | from .keydistraction import * 6 | from .Game import * 7 | from .Game_RL import * 8 | gym.envs.register( 9 | id='MiniGrid-SimpleDoorKey-Min5-Max10-View3', 10 | entry_point='env.singleroom:SingleRoomEnv', 11 | kwargs={'minRoomSize' : 5, \ 12 | 'maxRoomSize' : 10, \ 13 | 'agent_view_size' : 3}, 14 | ) 15 | 16 | gym.envs.register( 17 | id='MiniGrid-KeyInBox-Min5-Max10-View3', 18 | entry_point='env.boxkey:BoxKeyEnv', 19 | kwargs={'minRoomSize' : 5, \ 20 | 'maxRoomSize' : 10, \ 21 | 'agent_view_size' : 3}, 22 | ) 23 | 24 | gym.envs.register( 25 | id='MiniGrid-RandomBoxKey-Min5-Max10-View3', 26 | entry_point='env.randomboxkey:RandomBoxKeyEnv', 27 | kwargs={'minRoomSize' : 5, \ 28 | 'maxRoomSize' : 10, \ 29 | 'agent_view_size' : 3}, 30 | ) 31 | 32 | gym.envs.register( 33 | id='MiniGrid-ColoredDoorKey-Min5-Max10-View3', 34 | entry_point='env.keydistraction:KeyDistractionEnv', 35 | kwargs={'minRoomSize' : 5, \ 36 | 'maxRoomSize' : 10, \ 37 | 'minNumKeys' : 2, \ 38 | 'maxNumKeys' : 2, \ 39 | 'agent_view_size' : 3}, 40 | ) -------------------------------------------------------------------------------- /env/boxkey.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall, Box 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class BoxKeyEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1, 2, 3} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door 
position 66 | # Exit on right wall 67 | if exitDoorWall == 0: 68 | exitDoorPos = (topX + sizeX - 1, topY + self._rand_int(1, sizeY - 1)) 69 | # Exit on south wall 70 | elif exitDoorWall == 1: 71 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY + sizeY - 1) 72 | # Exit on left wall 73 | elif exitDoorWall == 2: 74 | exitDoorPos = (topX, topY + self._rand_int(1, sizeY - 1)) 75 | # Exit on north wall 76 | elif exitDoorWall == 3: 77 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY) 78 | else: 79 | assert False 80 | 81 | # Place the door 82 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 83 | exitDoor = Door(doorColor, is_locked=True) 84 | self.door = exitDoor 85 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 86 | 87 | # Randomize the starting agent position and direction 88 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 89 | 90 | # Randomize the box and key position 91 | key = Key(doorColor) 92 | box = Box(doorColor, contains=key) 93 | self.place_obj(box, (topX, topY), (sizeX, sizeY)) 94 | 95 | self.mission = "open the door" 96 | 97 | def step(self, action): 98 | obs, reward, terminated, truncated, info = super().step(action) 99 | 100 | if action == self.actions.toggle: 101 | if self.door.is_open: 102 | reward = self._reward() 103 | terminated = True 104 | 105 | return obs, reward, terminated, truncated, info -------------------------------------------------------------------------------- /env/historicalobs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium import spaces 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.world_object import Wall, WorldObj 5 | from minigrid.minigrid_env import MiniGridEnv 6 | from minigrid.core.mission import MissionSpace 7 | from minigrid.core.constants import TILE_PIXELS 8 | from minigrid.utils.rendering import fill_coords, point_in_rect, point_in_circle 9 | 10 | class Unseen(WorldObj): 11 | def __init__(self): 12 | super().__init__("unseen", "grey") 13 | 14 | def render(self, img): 15 | fill_coords(img, point_in_rect(0, 1, 0, 1), np.array([200, 200, 200])) 16 | # fill_coords(img, point_in_rect(0, 1, 0.45, 0.55), np.array([200, 200, 200])) 17 | 18 | class Footprint(WorldObj): 19 | def __init__(self): 20 | super().__init__("empty", "red") 21 | 22 | def render(self, img): 23 | fill_coords(img, point_in_circle(0.5, 0.5, 0.15), np.array([225, 0, 0])) 24 | 25 | class FlexibleGrid(Grid): 26 | def __init__(self, width: int, height: int): 27 | assert width >= 1 28 | assert height >= 1 29 | 30 | self.width: int = width 31 | self.height: int = height 32 | 33 | self.grid: list[WorldObj | None] = [None] * (width * height) 34 | self.mask: np.ndarray = np.zeros(shape=(self.width, self.height), dtype=bool) 35 | 36 | def looking_vertical(self, i, direction, out_of_bound): 37 | for j in range(0, self.height - 1): # north to south 38 | if not self.mask[i, j]: 39 | continue 40 | 41 | cell = self.get(i, j) 42 | if cell and not cell.see_behind(): 43 | continue 44 | 45 | # print(i, j) 46 | self.mask[i, j + 1] = True 47 | if not out_of_bound: 48 | self.mask[i + direction, j] = True 49 | self.mask[i + direction, j + 1] = True 50 | # print(self.mask.T) 51 | 52 | for j in reversed(range(1, self.height)): # south to north 53 | if not self.mask[i, j]: 54 | continue 55 | 56 | cell = self.get(i, j) 57 | if cell and not cell.see_behind(): 58 | continue 59 | 60 | # print(i, j) 61 | self.mask[i, j - 1] = True 62 | if not out_of_bound: 63 | self.mask[i 
+ direction, j] = True 64 | self.mask[i + direction, j - 1] = True 65 | # print(self.mask.T) 66 | 67 | def looking_horizontal(self, j, direction, out_of_bound): 68 | for i in range(0, self.width - 1): # west to east 69 | if not self.mask[i, j]: 70 | continue 71 | 72 | cell = self.get(i, j) 73 | if cell and not cell.see_behind(): 74 | continue 75 | 76 | # print(i, j) 77 | self.mask[i + 1, j] = True 78 | if not out_of_bound: 79 | self.mask[i, j + direction] = True 80 | self.mask[i + 1, j + direction] = True 81 | # print(self.mask.T) 82 | 83 | for i in reversed(range(1, self.width)): # east to west 84 | if not self.mask[i, j]: 85 | continue 86 | 87 | cell = self.get(i, j) 88 | if cell and not cell.see_behind(): 89 | continue 90 | 91 | # print(i, j) 92 | self.mask[i - 1, j] = True 93 | if not out_of_bound: 94 | self.mask[i, j + direction] = True 95 | self.mask[i - 1, j + direction] = True 96 | # print(self.mask.T) 97 | 98 | def process_vis(self, agent_pos, agent_dir): 99 | self.mask[agent_pos[0], agent_pos[1]] = True 100 | 101 | if agent_dir == 0: # east 102 | for i in range(0, self.width): 103 | out_of_bound = i == self.width - 1 104 | self.looking_vertical(i, 1, out_of_bound) 105 | 106 | elif agent_dir == 2: # west 107 | for i in reversed(range(0, self.width)): 108 | out_of_bound = i == 0 109 | self.looking_vertical(i, -1, out_of_bound) 110 | 111 | elif agent_dir == 1: # south 112 | for j in range(0, self.height): 113 | out_of_bound = j == self.height - 1 114 | self.looking_horizontal(j, 1, out_of_bound) 115 | 116 | elif agent_dir == 3: # north 117 | for j in reversed(range(0, self.height)): 118 | out_of_bound = j == 0 119 | self.looking_horizontal(j, -1, out_of_bound) 120 | return self.mask 121 | 122 | 123 | class HistoricalObsEnv(MiniGridEnv): 124 | def __init__( 125 | self, 126 | mission_space: MissionSpace, 127 | width: int | None = None, 128 | height: int | None = None, 129 | max_steps: int = 100, 130 | **kwargs, 131 | ): 132 | super().__init__( 133 | mission_space=mission_space, 134 | width=width, 135 | height=height, 136 | max_steps=max_steps, 137 | **kwargs, 138 | ) 139 | 140 | image_observation_space = spaces.Box( 141 | low=0, 142 | high=255, 143 | shape=(self.width, self.height, 4), 144 | dtype="uint8", 145 | ) 146 | mission_space = self.observation_space["mission"] 147 | self.observation_space = spaces.Dict( 148 | { 149 | "image": image_observation_space, 150 | "mission": mission_space, 151 | } 152 | ) 153 | 154 | self.mask = np.zeros(shape=(self.width, self.height), dtype=bool) 155 | 156 | def reset(self, seed=None): 157 | obs, info = super().reset(seed=seed) 158 | self.mask = np.zeros(shape=(self.width, self.height), dtype=bool) 159 | topX, topY, botX, botY = self.get_view_exts(agent_view_size=None, clip=True) 160 | vis_grid = self.slice_grid(topX, topY, botX - topX, botY - topY) 161 | if not self.see_through_walls: 162 | vis_mask = vis_grid.process_vis((self.agent_pos[0] - topX, self.agent_pos[1] - topY), self.agent_dir) 163 | else: 164 | vis_mask = np.ones(shape=(botX - topX, botY - topY), dtype=bool) 165 | 166 | self.mask[topX:botX, topY:botY] += vis_mask 167 | obs = self.gen_obs() 168 | return obs, info 169 | 170 | def slice_grid(self, topX, topY, width, height) -> FlexibleGrid: 171 | """ 172 | Get a subset of the grid 173 | """ 174 | 175 | vis_grid = FlexibleGrid(width, height) 176 | 177 | for j in range(0, height): 178 | for i in range(0, width): 179 | x = topX + i 180 | y = topY + j 181 | 182 | if 0 <= x < self.grid.width and 0 <= y < self.grid.height: 183 | v = 
self.grid.get(x, y) 184 | else: 185 | v = Wall() 186 | 187 | vis_grid.set(i, j, v) 188 | 189 | return vis_grid 190 | 191 | def get_view_exts(self, agent_view_size=None, clip=False): 192 | """ 193 | Get the extents of the square set of tiles visible to the agent 194 | if agent_view_size is None, use self.agent_view_size 195 | """ 196 | 197 | topX, topY, botX, botY = super().get_view_exts(agent_view_size) 198 | if clip: 199 | topX = max(0, topX) 200 | topY = max(0, topY) 201 | botX = min(botX, self.width) 202 | botY = min(botY, self.height) 203 | return topX, topY, botX, botY 204 | 205 | def gen_hist_obs_grid(self, agent_view_size=None): 206 | topX, topY, botX, botY = self.get_view_exts(agent_view_size, clip=True) 207 | grid = self.grid.copy() 208 | vis_grid = self.slice_grid(topX, topY, botX - topX, botY - topY) 209 | if not self.see_through_walls: 210 | vis_mask = vis_grid.process_vis((self.agent_pos[0] - topX, self.agent_pos[1] - topY), self.agent_dir) 211 | else: 212 | vis_mask = np.ones(shape=(botX - topX, botY - topY), dtype=bool) 213 | 214 | self.mask[topX:botX, topY:botY] += vis_mask 215 | 216 | # Make it so the agent sees what it's carrying 217 | if self.carrying: 218 | grid.set(*self.agent_pos, self.carrying) 219 | else: 220 | grid.set(*self.agent_pos, None) 221 | 222 | return grid 223 | 224 | def gen_obs(self): 225 | grid = self.gen_hist_obs_grid() 226 | 227 | image = grid.encode(self.mask) 228 | 229 | agent_pos_dir = np.zeros((self.width, self.height), dtype="uint8") + 4 230 | agent_pos_dir[self.agent_pos] = self.agent_dir 231 | 232 | obs = {"image": np.concatenate((image, agent_pos_dir[:,:,None]), axis=2), "mission": self.mission} 233 | return obs 234 | 235 | def get_full_render(self, highlight, tile_size): 236 | grid = self.gen_hist_obs_grid() 237 | img = grid.render( 238 | tile_size, 239 | agent_pos=self.agent_pos, 240 | agent_dir=self.agent_dir, 241 | highlight_mask=self.mask, 242 | ) 243 | return img 244 | 245 | def get_mask_render(self, path_mask=None, tile_size=TILE_PIXELS): 246 | grid = self.gen_hist_obs_grid() 247 | unseen_mask = np.ones(shape=(self.width, self.height), dtype=bool) ^ self.mask 248 | 249 | if path_mask is None: 250 | path_mask = np.zeros(shape=(self.width, self.height), dtype=bool) 251 | 252 | # Compute the total grid size 253 | width_px = self.width * tile_size 254 | height_px = self.height * tile_size 255 | 256 | img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) 257 | 258 | # Render the grid 259 | for j in range(0, self.height): 260 | for i in range(0, self.width): 261 | if unseen_mask[i, j]: 262 | cell = Unseen() 263 | elif path_mask[i, j]: 264 | cell = Footprint() 265 | else: 266 | cell = grid.get(i, j) 267 | 268 | agent_here = np.array_equal(self.agent_pos, (i, j)) 269 | tile_img = Grid.render_tile( 270 | cell, 271 | agent_dir=self.agent_dir if agent_here else None, 272 | highlight=False, 273 | tile_size=tile_size, 274 | ) 275 | 276 | ymin = j * tile_size 277 | ymax = (j + 1) * tile_size 278 | xmin = i * tile_size 279 | xmax = (i + 1) * tile_size 280 | img[ymin:ymax, xmin:xmax, :] = tile_img 281 | 282 | return img 283 | -------------------------------------------------------------------------------- /env/keydistraction.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall 6 | from 
.historicalobs import HistoricalObsEnv 7 | 8 | 9 | class KeyDistractionEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | minNumKeys: int = 2, 15 | maxNumKeys: int = 3, 16 | agent_view_size: int = 7, 17 | max_steps: int | None = None, 18 | **kwargs, 19 | ): 20 | 21 | assert minNumKeys > 0 22 | self.minRoomSize = minRoomSize 23 | self.maxRoomSize = maxRoomSize 24 | self.minNumKeys = minNumKeys 25 | self.maxNumKeys = maxNumKeys 26 | 27 | mission_space = MissionSpace(mission_func=self._gen_mission) 28 | 29 | if max_steps is None: 30 | max_steps = maxRoomSize ** 2 31 | 32 | super().__init__( 33 | mission_space=mission_space, 34 | width=maxRoomSize, 35 | height=maxRoomSize, 36 | agent_view_size=agent_view_size, 37 | max_steps=max_steps, 38 | **kwargs, 39 | ) 40 | 41 | @staticmethod 42 | def _gen_mission(): 43 | return "open the door" 44 | 45 | def _gen_grid(self, width, height): 46 | 47 | # Create the grid 48 | self.grid = Grid(width, height) 49 | 50 | # Choose the room size randomly 51 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 52 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 53 | topX, topY = 0, 0 54 | 55 | # Draw the top and bottom walls 56 | wall = Wall() 57 | for i in range(0, sizeX): 58 | self.grid.set(topX + i, topY, wall) 59 | self.grid.set(topX + i, topY + sizeY - 1, wall) 60 | 61 | # Draw the left and right walls 62 | for j in range(0, sizeY): 63 | self.grid.set(topX, topY + j, wall) 64 | self.grid.set(topX + sizeX - 1, topY + j, wall) 65 | 66 | # Pick which wall to place the out door on 67 | wallSet = {0, 1, 2, 3} 68 | exitDoorWall = self._rand_elem(sorted(wallSet)) 69 | 70 | # Pick the exit door position 71 | # Exit on right wall 72 | if exitDoorWall == 0: 73 | rand_int = self._rand_int(1, sizeY - 1) 74 | exitDoorPos = (topX + sizeX - 1, topY + rand_int) 75 | rejectPos = (topX + sizeX - 2, topY + rand_int) 76 | # Exit on south wall 77 | elif exitDoorWall == 1: 78 | rand_int = self._rand_int(1, sizeX - 1) 79 | exitDoorPos = (topX + rand_int, topY + sizeY - 1) 80 | rejectPos = (topX + rand_int, topY + sizeY - 2) 81 | # Exit on left wall 82 | elif exitDoorWall == 2: 83 | rand_int = self._rand_int(1, sizeY - 1) 84 | exitDoorPos = (topX, topY + rand_int) 85 | rejectPos = (topX + 1, topY + rand_int) 86 | # Exit on north wall 87 | elif exitDoorWall == 3: 88 | rand_int = self._rand_int(1, sizeX - 1) 89 | exitDoorPos = (topX + rand_int, topY) 90 | rejectPos = (topX + rand_int, topY + 1) 91 | else: 92 | assert False 93 | 94 | # Place the door 95 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 96 | exitDoor = Door(doorColor, is_locked=True) 97 | self.door = exitDoor 98 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 99 | 100 | # Randomize the starting agent position and direction 101 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 102 | 103 | # Randomize the key position 104 | reject_fn = lambda env, pos: pos == rejectPos 105 | numKeys = self._rand_int(self.minNumKeys, self.maxNumKeys + 1) 106 | key = Key(doorColor) 107 | self.place_obj(key, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 108 | keyColors = set(COLOR_NAMES) 109 | keyColors.remove(doorColor) 110 | for i in range(numKeys - 1): 111 | keyColor = self._rand_elem(sorted(keyColors)) 112 | key = Key(keyColor) 113 | self.place_obj(key, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 114 | 115 | self.mission = "open the door" 116 | 117 | def step(self, action): 118 | obs, reward, terminated, 
truncated, info = super().step(action) 119 | 120 | if action == self.actions.toggle: 121 | if self.door.is_open: 122 | reward = self._reward() 123 | terminated = True 124 | 125 | return obs, reward, terminated, truncated, info -------------------------------------------------------------------------------- /env/randomboxkey.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall, Box 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class RandomBoxKeyEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1, 2, 3} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door position 66 | # Exit on right wall 67 | if exitDoorWall == 0: 68 | rand_int = self._rand_int(1, sizeY - 1) 69 | exitDoorPos = (topX + sizeX - 1, topY + rand_int) 70 | rejectPos = (topX + sizeX - 2, topY + rand_int) 71 | # Exit on south wall 72 | elif exitDoorWall == 1: 73 | rand_int = self._rand_int(1, sizeX - 1) 74 | exitDoorPos = (topX + rand_int, topY + sizeY - 1) 75 | rejectPos = (topX + rand_int, topY + sizeY - 2) 76 | # Exit on left wall 77 | elif exitDoorWall == 2: 78 | rand_int = self._rand_int(1, sizeY - 1) 79 | exitDoorPos = (topX, topY + rand_int) 80 | rejectPos = (topX + 1, topY + rand_int) 81 | # Exit on north wall 82 | elif exitDoorWall == 3: 83 | rand_int = self._rand_int(1, sizeX - 1) 84 | exitDoorPos = (topX + rand_int, topY) 85 | rejectPos = (topX + rand_int, topY + 1) 86 | else: 87 | assert False 88 | 89 | # Place the door 90 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 91 | exitDoor = Door(doorColor, is_locked=True) 92 | self.door = exitDoor 93 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 94 | 95 | # Randomize the starting agent position and direction 96 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 97 | 98 | reject_fn = lambda env, pos: pos == rejectPos 99 | KeyinBox = {0, 1} 100 | KeyinBox = self._rand_elem({0, 1}) 101 | if 
KeyinBox == 0: 102 | # Randomize the box and key position 103 | key = Key(doorColor) 104 | box = Box(doorColor, contains=None) 105 | self.place_obj(key, (topX, topY), (sizeX, sizeY)) 106 | self.place_obj(box, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 107 | else: 108 | key = Key(doorColor) 109 | box = Box(doorColor, contains=key) 110 | self.place_obj(box, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 111 | 112 | self.mission = "open the door" 113 | 114 | def step(self, action): 115 | obs, reward, terminated, truncated, info = super().step(action) 116 | 117 | if action == self.actions.toggle: 118 | if self.door.is_open: 119 | reward = self._reward() 120 | terminated = True 121 | 122 | return obs, reward, terminated, truncated, info -------------------------------------------------------------------------------- /env/singleroom.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class SingleRoomEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1, 2, 3} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door position 66 | # Exit on right wall 67 | if exitDoorWall == 0: 68 | exitDoorPos = (topX + sizeX - 1, topY + self._rand_int(1, sizeY - 1)) 69 | # Exit on south wall 70 | elif exitDoorWall == 1: 71 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY + sizeY - 1) 72 | # Exit on left wall 73 | elif exitDoorWall == 2: 74 | exitDoorPos = (topX, topY + self._rand_int(1, sizeY - 1)) 75 | # Exit on north wall 76 | elif exitDoorWall == 3: 77 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY) 78 | else: 79 | assert False 80 | 81 | # Place the door 82 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 83 | exitDoor = Door(doorColor, is_locked=True) 84 | self.door = exitDoor 85 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 86 | 87 | # Randomize the starting agent 
position and direction 88 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 89 | 90 | # Randomize the key position 91 | key = Key(doorColor) 92 | self.place_obj(key, (topX, topY), (sizeX, sizeY)) 93 | 94 | self.mission = "open the door" 95 | 96 | def step(self, action): 97 | obs, reward, terminated, truncated, info = super().step(action) 98 | 99 | if action == self.actions.toggle: 100 | if self.door.is_open: 101 | reward = self._reward() 102 | terminated = True 103 | 104 | return obs, reward, terminated, truncated, info 105 | 106 | -------------------------------------------------------------------------------- /executive_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from skill import GoTo_Goal, Explore, Pickup, Toggle 7 | from utils import global_param 8 | from mediator import IDX_TO_SKILL, IDX_TO_OBJECT 9 | 10 | class Executive_net(): 11 | def __init__(self, skill_list , init_obs=None, agent_view_size=None): 12 | 13 | assert len(skill_list) > 0 14 | 15 | self.skill_list = skill_list 16 | self.num_of_skills = len(skill_list) 17 | self.current_index = -1 18 | self.agent_view_size = agent_view_size 19 | 20 | self.actor = self.switch_to_next_skill(init_obs) 21 | self.skill_done = False 22 | 23 | # current_skill = skill_list[0] 24 | # self.actor, self.empty_actor = self.switch_skill(current_skill, init_obs) 25 | # self.skill_done = False 26 | 27 | @property 28 | def current_skill(self): 29 | skill = self.skill_list[self.current_index] 30 | return IDX_TO_SKILL[skill['action']] + ' ' + IDX_TO_OBJECT[skill['object']] 31 | 32 | def switch_skill(self, skill, obs): 33 | self.action = skill['action'] 34 | if self.action == 0: 35 | exp = global_param.get_value('exp') 36 | actor = Explore(obs, self.agent_view_size, exp) 37 | global_param.set_value('exp', actor) 38 | elif self.action == 1: 39 | actor = GoTo_Goal(obs, skill['coordinate']) 40 | global_param.set_value('exp', None) 41 | elif self.action == 2: 42 | actor = Pickup(obs) 43 | elif self.action == 4: 44 | actor = Toggle() 45 | else: 46 | actor = None 47 | if actor is None or actor.done_check(): 48 | return None 49 | return actor #, actor.done_check() 50 | 51 | def switch_to_next_skill(self, obs): 52 | # not_valid_actor = True 53 | actor = None 54 | while actor is None: 55 | self.current_index += 1 56 | if self.current_index >= self.num_of_skills: 57 | return None # return None when no skill left in list. 
58 | next_skill = self.skill_list[self.current_index] 59 | actor = self.switch_skill(next_skill, obs) 60 | return actor 61 | 62 | def __call__(self, obs): 63 | 64 | if self.actor is None: 65 | return np.array([6]), True 66 | 67 | if self.skill_done: 68 | self.actor = self.switch_to_next_skill(obs) 69 | self.skill_done = False 70 | 71 | # obs = obs if self.action == 0 else None 72 | if self.actor is None: 73 | return np.array([6]), True 74 | 75 | action, done = self.actor.step(obs) 76 | 77 | if done and self.current_index == self.num_of_skills - 1: 78 | return action, True 79 | elif done: 80 | self.skill_done = True 81 | return action, False 82 | else: 83 | return action, False 84 | 85 | 86 | -------------------------------------------------------------------------------- /img/always.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJLAB-AMMI/LLM4RL/efaf4ee669296f4e3939c4cb3493a39b9458182b/img/always.gif -------------------------------------------------------------------------------- /img/baseline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJLAB-AMMI/LLM4RL/efaf4ee669296f4e3939c4cb3493a39b9458182b/img/baseline.gif -------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJLAB-AMMI/LLM4RL/efaf4ee669296f4e3939c4cb3493a39b9458182b/img/framework.png -------------------------------------------------------------------------------- /img/hard_code.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJLAB-AMMI/LLM4RL/efaf4ee669296f4e3939c4cb3493a39b9458182b/img/hard_code.gif -------------------------------------------------------------------------------- /img/ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJLAB-AMMI/LLM4RL/efaf4ee669296f4e3939c4cb3493a39b9458182b/img/ours.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os,json, sys 3 | import numpy as np 4 | # single gpu 5 | 6 | os.system('nvidia-smi -q -d Memory | grep -A5 GPU | grep Free > tmp.txt') 7 | memory_gpu = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()] 8 | os.environ["CUDA_VISIBLE_DEVICES"] = str(np.argmax(memory_gpu)) 9 | os.system('rm tmp.txt') 10 | 11 | import torch 12 | import utils 13 | 14 | def setup_seed(seed): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | np.random.seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | 20 | 21 | if __name__ == "__main__": 22 | utils.print_logo(subtitle="Maintained by Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab") 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--task", type=str, default="SimpleDoorKey", help="SimpleDoorKey, KeyInBox, RandomBoxKey, ColoredDoorKey") 25 | parser.add_argument("--save_name", type=str, required=True, help="path to folder containing policy and run details") 26 | parser.add_argument("--logdir", type=str, default="./log/") # Where to log diagnostics to 27 | parser.add_argument("--record", default=False, action='store_true') 28 | parser.add_argument("--seed", default=None) 
29 | parser.add_argument("--ask_lambda", type=float, default=0.01, help="weight on communication penalty term") 30 | parser.add_argument("--batch_size", type=int, default=32) 31 | parser.add_argument("--lam", type=float, default=0.95, help="Generalized advantage estimate discount") 32 | parser.add_argument("--gamma", type=float, default=0.99, help="MDP discount") 33 | parser.add_argument("--n_itr", type=int, default=1000, help="Number of iterations of the learning algorithm") 34 | parser.add_argument("--policy", type=str, default='ppo') 35 | parser.add_argument( 36 | "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" 37 | ) 38 | parser.add_argument("--traj_per_itr", type=int, default=10) 39 | parser.add_argument("--show", default=False, action='store_true') 40 | parser.add_argument("--test_num", type=int, default=100) 41 | 42 | parser.add_argument("--frame_stack", type=int, default=1) 43 | parser.add_argument("--run_seed_list", type=int, nargs="*", default=[0]) 44 | 45 | 46 | 47 | if sys.argv[1] == 'eval': 48 | sys.argv.remove(sys.argv[1]) 49 | args = parser.parse_args() 50 | 51 | output_dir = os.path.join(args.logdir, args.policy, args.task, args.save_name) 52 | 53 | policy = torch.load(output_dir + "/acmodel.pt") 54 | policy.eval() 55 | eval = utils.Eval(args,policy) 56 | eval.eval_policy(args.test_num) 57 | elif sys.argv[1] == 'eval_RL': 58 | sys.argv.remove(sys.argv[1]) 59 | args = parser.parse_args() 60 | output_dir = os.path.join(args.logdir, args.policy, args.task, args.save_name) 61 | 62 | policy = torch.load(output_dir + "/acmodel.pt") 63 | policy.eval() 64 | eval = utils.Eval(args,policy) 65 | eval.eval_RL_policy(args.test_num) 66 | 67 | elif sys.argv[1] == 'train': 68 | sys.argv.remove(sys.argv[1]) 69 | args = parser.parse_args() 70 | from env.Game import Game 71 | 72 | for i in args.run_seed_list: 73 | setup_seed(i) 74 | args.save_name = args.save_name + str(i) 75 | game = Game(args, run_seed=i) 76 | game.reset() 77 | game.train() 78 | elif sys.argv[1] == 'train_RL': 79 | sys.argv.remove(sys.argv[1]) 80 | args = parser.parse_args() 81 | from env.Game_RL import Game_RL 82 | game = Game_RL(args) 83 | game.reset() 84 | game.train() 85 | elif sys.argv[1] == 'baseline': 86 | sys.argv.remove(sys.argv[1]) 87 | args = parser.parse_args() 88 | eval = utils.Eval(args) 89 | eval.eval_baseline(args.test_num) 90 | elif sys.argv[1] == 'random': 91 | sys.argv.remove(sys.argv[1]) 92 | args = parser.parse_args() 93 | eval = utils.Eval(args) 94 | eval.eval_policy(args.test_num) 95 | elif sys.argv[1] == 'always': 96 | sys.argv.remove(sys.argv[1]) 97 | args = parser.parse_args() 98 | eval = utils.Eval(args) 99 | eval.eval_always_ask(args.test_num) 100 | else: 101 | print("Invalid option '{}'".format(sys.argv[1])) 102 | 103 | -------------------------------------------------------------------------------- /mediator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : mediator.py 5 | @Time : 2023/05/16 10:22:36 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | import numpy as np 12 | import re 13 | import copy 14 | from abc import ABC, abstractmethod 15 | @staticmethod 16 | def get_minigrid_words(): 17 | colors = ["red", "green", "blue", "yellow", "purple", "grey"] 18 | objects = [ 19 | "unseen", 20 | "empty", 21 | "wall", 22 | "floor", 23 | "box", 24 | "key", 25 | "ball", 26 | "door", 27 | "goal", 28 | "agent", 29 | "lava", 30 | ] 31 | 32 | 
verbs = [ 33 | "pick", 34 | "avoid", 35 | "get", 36 | "find", 37 | "put", 38 | "use", 39 | "open", 40 | "go", 41 | "fetch", 42 | "reach", 43 | "unlock", 44 | "traverse", 45 | ] 46 | 47 | extra_words = [ 48 | "up", 49 | "the", 50 | "a", 51 | "at", 52 | ",", 53 | "square", 54 | "and", 55 | "then", 56 | "to", 57 | "of", 58 | "rooms", 59 | "near", 60 | "opening", 61 | "must", 62 | "you", 63 | "matching", 64 | "end", 65 | "hallway", 66 | "object", 67 | "from", 68 | "room", 69 | ] 70 | 71 | all_words = colors + objects + verbs + extra_words 72 | assert len(all_words) == len(set(all_words)) 73 | return {word: i for i, word in enumerate(all_words)} 74 | 75 | # Map of agent direction, 0: East; 1: South; 2: West; 3: North 76 | DIRECTION = { 77 | 0: [1, 0], 78 | 1: [0, 1], 79 | 2: [-1, 0], 80 | 3: [0, -1], 81 | } 82 | 83 | # Map of object type to integers 84 | OBJECT_TO_IDX = { 85 | "unseen": 0, 86 | "empty": 1, 87 | "wall": 2, 88 | "floor": 3, 89 | "door": 4, 90 | "key": 5, 91 | "ball": 6, 92 | "box": 7, 93 | "goal": 8, 94 | "lava": 9, 95 | "agent": 10, 96 | } 97 | IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys())) 98 | 99 | # Used to map colors to integers 100 | COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5} 101 | IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys())) 102 | 103 | # Map of state names to integers 104 | STATE_TO_IDX = { 105 | "open": 0, 106 | "closed": 1, 107 | "locked": 2, 108 | } 109 | IDX_TO_STATE = dict(zip(STATE_TO_IDX.values(), STATE_TO_IDX.keys())) 110 | 111 | # Map of skill names to integers 112 | SKILL_TO_IDX = {"explore": 0, "go to object": 1, "pickup": 2, "drop": 3, "toggle": 4} 113 | IDX_TO_SKILL = dict(zip(SKILL_TO_IDX.values(), SKILL_TO_IDX.keys())) 114 | 115 | 116 | 117 | 118 | class Base_Mediator(ABC): 119 | """The base class for Base_Mediator.""" 120 | 121 | def __init__(self): 122 | super().__init__() 123 | self.obj_coordinate = {} 124 | 125 | @abstractmethod 126 | def RL2LLM(self): 127 | pass 128 | @abstractmethod 129 | def LLM2RL(self): 130 | pass 131 | 132 | def reset(self): 133 | self.obj_coordinate = {} 134 | 135 | class SimpleDoorKey_Mediator(Base_Mediator): 136 | def __init__(self, ): 137 | super().__init__() 138 | 139 | # ## obs2text 140 | def RL2LLM(self, obs): 141 | context = '' 142 | if len(obs.shape) == 4: 143 | obs = obs[0,:,:,-4:] 144 | obs_object = copy.deepcopy(obs[:,:,0]) 145 | agent_map = obs[:, :, 3] 146 | agent_pos = np.argwhere(agent_map != 4)[0] 147 | 148 | key_list = np.argwhere(obs_object==5) 149 | door_list = np.argwhere(obs_object==4) 150 | carry_object = None 151 | if len(key_list): 152 | for key in key_list: 153 | i, j = key 154 | color = obs[i,j,1] 155 | #object = f"{IDX_TO_COLOR[color]} key" 156 | object = "key" 157 | if (agent_pos == key).all(): 158 | carry_object = object 159 | else: 160 | if context != "": 161 | context += ", " 162 | context += f"observed a {object}" 163 | self.obj_coordinate[object] = (i,j) 164 | 165 | if len(door_list): 166 | for door in door_list: 167 | i, j = door 168 | color = obs[i,j,1] 169 | #object = f"{IDX_TO_COLOR[color]} door" 170 | object = "door" 171 | if context != "": 172 | context += ", " 173 | context += f"observed a {object}" 174 | self.obj_coordinate[object] = (i,j) 175 | 176 | if context == '': 177 | context += "observed {}".format('nothing') 178 | if carry_object is not None: 179 | if context != "": 180 | context += ", " 181 | context += f"carry {carry_object}" 182 | context = f"observation: {{{context}}} " 183 | return 
context 184 | 185 | def parser(self, text): 186 | # action: 187 | if "explore" in text: 188 | act = SKILL_TO_IDX["explore"] 189 | elif "go to" in text: 190 | act = SKILL_TO_IDX["go to object"] 191 | elif "pick up" in text: 192 | act = SKILL_TO_IDX["pickup"] 193 | elif "drop" in text: 194 | act = SKILL_TO_IDX["drop"] 195 | elif "toggle" in text: 196 | act = SKILL_TO_IDX["toggle"] 197 | else: 198 | print("Unknown Planning :", text) 199 | act = 6 # do nothing 200 | # object: 201 | try: 202 | if "key" in text: 203 | obj = OBJECT_TO_IDX["key"] 204 | coordinate = self.obj_coordinate["key"] 205 | elif "door" in text: 206 | obj = OBJECT_TO_IDX["door"] 207 | coordinate = self.obj_coordinate["door"] 208 | else: 209 | obj = OBJECT_TO_IDX["empty"] 210 | coordinate = None 211 | except: 212 | print("Unknown Planning :", text) 213 | act = 6 # do nothing 214 | return act, obj, coordinate 215 | 216 | 217 | def LLM2RL(self, plan): 218 | 219 | plan = re.findall(r'{(.*?)}', plan) 220 | lines = plan[0].split(',') 221 | skill_list = [] 222 | for line in lines: 223 | 224 | action, object, coordinate = self.parser(line) 225 | goal = {} 226 | goal["action"] = action 227 | goal["object"] = object 228 | goal["coordinate"] = coordinate 229 | skill_list.append(goal) 230 | return skill_list 231 | 232 | class KeyInBox_Mediator(Base_Mediator): 233 | def __init__(self): 234 | super().__init__() 235 | 236 | # ## obs2text 237 | def RL2LLM(self, obs): 238 | context = '' 239 | if len(obs.shape) == 4: 240 | obs = obs[0,:,:,-4:] 241 | obs_object = copy.deepcopy(obs[:,:,0]) 242 | agent_map = obs[:, :, 3] 243 | agent_pos = np.argwhere(agent_map != 4)[0] 244 | 245 | key_list = np.argwhere(obs_object==5) 246 | door_list = np.argwhere(obs_object==4) 247 | box_list = np.argwhere(obs_object==7) 248 | carry_object = None 249 | if len(key_list): 250 | for key in key_list: 251 | i, j = key 252 | object = "key" 253 | if (agent_pos == key).all(): 254 | carry_object = object 255 | else: 256 | if context != "": 257 | context += ", " 258 | context += f"observed a {object}" 259 | self.obj_coordinate[object] = (i,j) 260 | if len(box_list): 261 | for box in box_list: 262 | i, j = box 263 | object = "box" 264 | if context != "": 265 | context += ", " 266 | context += f"observed a {object}" 267 | self.obj_coordinate[object] = (i,j) 268 | if len(door_list): 269 | for door in door_list: 270 | i, j = door 271 | object = "door" 272 | if context != "": 273 | context += ", " 274 | context += f"observed a {object}" 275 | self.obj_coordinate[object] = (i,j) 276 | 277 | if context == '': 278 | context += "observed {}".format('nothing') 279 | context = f"observation: {{{context}}}, " 280 | if carry_object is not None: 281 | if context != "": 282 | context += ", " 283 | context += f"carry {carry_object}" 284 | return context 285 | 286 | def parser(self, text): 287 | # action: 288 | if "explore" in text: 289 | act = SKILL_TO_IDX["explore"] 290 | elif "go to" in text: 291 | act = SKILL_TO_IDX["go to object"] 292 | elif "pick up" in text: 293 | act = SKILL_TO_IDX["pickup"] 294 | elif "drop" in text: 295 | act = SKILL_TO_IDX["drop"] 296 | elif "toggle" in text: 297 | act = SKILL_TO_IDX["toggle"] 298 | else: 299 | print("Unknown Planning :", text) 300 | act = 6 # do nothing 301 | # object: 302 | try: 303 | if "key" in text: 304 | obj = OBJECT_TO_IDX["key"] 305 | if "key" in self.obj_coordinate.keys(): 306 | coordinate = self.obj_coordinate["key"] 307 | else: 308 | coordinate = self.obj_coordinate["box"] 309 | elif "door" in text: 310 | obj = OBJECT_TO_IDX["door"] 
311 | coordinate = self.obj_coordinate["door"] 312 | elif "box" in text: 313 | obj = OBJECT_TO_IDX["box"] 314 | coordinate = self.obj_coordinate["box"] 315 | else: 316 | obj = OBJECT_TO_IDX["empty"] 317 | coordinate = None 318 | except: 319 | print("Unknown Planning :", text) 320 | act = 6 # do nothing 321 | return act, obj, coordinate 322 | 323 | 324 | def LLM2RL(self, plan): 325 | plan = re.findall(r'{(.*?)}', plan) 326 | lines = plan[0].split(',') 327 | skill_list = [] 328 | for line in lines: 329 | 330 | action, object, coordinate = self.parser(line) 331 | goal = {} 332 | goal["action"] = action 333 | goal["object"] = object 334 | goal["coordinate"] = coordinate 335 | skill_list.append(goal) 336 | return skill_list 337 | 338 | class RandomBoxKey_Mediator(Base_Mediator): 339 | def __init__(self): 340 | super().__init__() 341 | 342 | # ## obs2text 343 | def RL2LLM(self, obs): 344 | context = '' 345 | if len(obs.shape) == 4: 346 | obs = obs[0,:,:,-4:] 347 | obs_object = copy.deepcopy(obs[:,:,0]) 348 | agent_map = obs[:, :, 3] 349 | agent_pos = np.argwhere(agent_map != 4)[0] 350 | 351 | key_list = np.argwhere(obs_object==5) 352 | door_list = np.argwhere(obs_object==4) 353 | box_list = np.argwhere(obs_object==7) 354 | carry_object = None 355 | if len(key_list): 356 | for key in key_list: 357 | i, j = key 358 | object = "key" 359 | if (agent_pos == key).all(): 360 | carry_object = object 361 | else: 362 | if context != "": 363 | context += ", " 364 | context += f"observed a {object}" 365 | self.obj_coordinate[object] = (i,j) 366 | if len(box_list): 367 | for box in box_list: 368 | i, j = box 369 | object = "box" 370 | if context != "": 371 | context += ", " 372 | context += f"observed a {object}" 373 | self.obj_coordinate[object] = (i,j) 374 | if len(door_list): 375 | for door in door_list: 376 | i, j = door 377 | object = "door" 378 | if context != "": 379 | context += ", " 380 | context += f"observed a {object}" 381 | self.obj_coordinate[object] = (i,j) 382 | 383 | if context == '': 384 | context += "observed {}".format('nothing') 385 | if carry_object is not None: 386 | if context != "": 387 | context += ", " 388 | context += f"carry {carry_object}" 389 | context = f"observation: {{{context}}} " 390 | return context 391 | 392 | def parser(self, text): 393 | # action: 394 | if "explore" in text: 395 | act = SKILL_TO_IDX["explore"] 396 | elif "go to" in text: 397 | act = SKILL_TO_IDX["go to object"] 398 | elif "pick up" in text: 399 | act = SKILL_TO_IDX["pickup"] 400 | elif "drop" in text: 401 | act = SKILL_TO_IDX["drop"] 402 | elif "toggle" in text: 403 | act = SKILL_TO_IDX["toggle"] 404 | else: 405 | print("Unknown Planning :", text) 406 | act = 6 # do nothing 407 | # object: 408 | try: 409 | if "key" in text: 410 | obj = OBJECT_TO_IDX["key"] 411 | if "key" in self.obj_coordinate.keys(): 412 | coordinate = self.obj_coordinate["key"] 413 | else: 414 | coordinate = self.obj_coordinate["box"] 415 | elif "door" in text: 416 | obj = OBJECT_TO_IDX["door"] 417 | coordinate = self.obj_coordinate["door"] 418 | elif "box" in text: 419 | obj = OBJECT_TO_IDX["box"] 420 | coordinate = self.obj_coordinate["box"] 421 | else: 422 | obj = OBJECT_TO_IDX["empty"] 423 | coordinate = None 424 | except: 425 | print("Unknown Planning :", text) 426 | act = 6 # do nothing 427 | return act, obj, coordinate 428 | 429 | 430 | def LLM2RL(self, plan): 431 | plan = re.findall(r'{(.*?)}', plan) 432 | lines = plan[0].split(',') 433 | skill_list = [] 434 | for line in lines: 435 | 436 | action, object, coordinate = 
self.parser(line) 437 | goal = {} 438 | goal["action"] = action 439 | goal["object"] = object 440 | goal["coordinate"] = coordinate 441 | skill_list.append(goal) 442 | return skill_list 443 | 444 | class ColoredDoorKey_Mediator(Base_Mediator): 445 | def __init__(self): 446 | super().__init__() 447 | 448 | # ## obs2text 449 | def RL2LLM(self, obs): 450 | context = '' 451 | if len(obs.shape) == 4: 452 | obs = obs[0,:,:,-4:] 453 | obs_object = copy.deepcopy(obs[:,:,0]) 454 | agent_map = obs[:, :, 3] 455 | agent_pos = np.argwhere(agent_map != 4)[0] 456 | agent_dir = agent_map[agent_pos[0],agent_pos[1]] 457 | 458 | key_list = np.argwhere(obs_object==5) 459 | door_list = np.argwhere(obs_object==4) 460 | 461 | block_object = None 462 | carry_object = None 463 | if len(key_list): 464 | for key in key_list: 465 | i, j = key 466 | color = obs[i,j,1] 467 | object = f"{IDX_TO_COLOR[color]} key" 468 | 469 | if (agent_pos == key).all(): 470 | carry_object = object 471 | else: 472 | if context != "": 473 | context += ", " 474 | context += f"observed {object}" 475 | self.obj_coordinate[object] = (i,j) 476 | 477 | if (agent_pos + DIRECTION[agent_dir] == key).all(): 478 | block_object = object 479 | if len(door_list): 480 | for door in door_list: 481 | i, j = door 482 | color = obs[i,j,1] 483 | object = f"{IDX_TO_COLOR[color]} door" 484 | if context != "": 485 | context += ", " 486 | context += f"observed {object}" 487 | self.obj_coordinate[object] = (i,j) 488 | 489 | # if block_object is not None: 490 | # if context != "": 491 | # context += ", " 492 | # context += f"block by {block_object}" 493 | 494 | if context == '': 495 | context += "observed {}".format('nothing') 496 | if carry_object is not None: 497 | if context != "": 498 | context += ", " 499 | context += f"carry {carry_object}" 500 | #context = f"observation:{{{context}}}," 501 | context = f"Q: [{context}]" 502 | return context 503 | 504 | def parser(self, text): 505 | # action: 506 | if "explore" in text: 507 | act = SKILL_TO_IDX["explore"] 508 | elif "go to" in text: 509 | act = SKILL_TO_IDX["go to object"] 510 | elif "pick up" in text: 511 | act = SKILL_TO_IDX["pickup"] 512 | elif "drop" in text: 513 | act = SKILL_TO_IDX["drop"] 514 | elif "toggle" in text: 515 | act = SKILL_TO_IDX["toggle"] 516 | else: 517 | print("Unknown Planning :", text) 518 | act = 6 # do nothing 519 | # object: 520 | try: 521 | if "key" in text: 522 | obj = OBJECT_TO_IDX["key"] 523 | words = text.split(' ') 524 | filter_words = [] 525 | for w in words: 526 | w1="".join(c for c in w if c.isalpha()) 527 | filter_words.append(w1) 528 | object_word = filter_words[-2] + " " + filter_words[-1] 529 | coordinate = self.obj_coordinate[object_word] 530 | elif "door" in text: 531 | obj = OBJECT_TO_IDX["door"] 532 | words = text.split(' ') 533 | filter_words = [] 534 | for w in words: 535 | w1="".join(c for c in w if c.isalpha()) 536 | filter_words.append(w1) 537 | object_word = filter_words[-2] + " " + filter_words[-1] 538 | coordinate = self.obj_coordinate[object_word] 539 | else: 540 | obj = OBJECT_TO_IDX["empty"] 541 | coordinate = None 542 | except: 543 | print("Unknown Planning :", text) 544 | coordinate = None 545 | act = 6 # do nothing 546 | return act, obj, coordinate 547 | 548 | 549 | def LLM2RL(self, plan): 550 | plan = re.findall(r'{(.*?)}', plan) 551 | lines = plan[0].split(',') 552 | skill_list = [] 553 | for line in lines: 554 | 555 | action, object, coordinate = self.parser(line) 556 | goal = {} 557 | goal["action"] = action 558 | goal["object"] = object 559 | 
goal["coordinate"] = coordinate 560 | skill_list.append(goal) 561 | return skill_list 562 | 563 | 564 | if __name__ == "__main__": 565 | word = get_minigrid_words() -------------------------------------------------------------------------------- /planner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : planner.py 5 | @Time : 2023/05/16 09:12:11 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | 12 | import os, requests 13 | from typing import Any 14 | from mediator import * 15 | from utils import global_param 16 | 17 | 18 | from abc import ABC, abstractmethod 19 | 20 | class Base_Planner(ABC): 21 | """The base class for Planner.""" 22 | 23 | def __init__(self): 24 | super().__init__() 25 | self.dialogue_system = '' 26 | self.dialogue_user = '' 27 | self.dialogue_logger = '' 28 | self.show_dialogue = False 29 | self.llm_model = None 30 | self.llm_url = None 31 | def reset(self, show=False): 32 | self.dialogue_user = '' 33 | self.dialogue_logger = '' 34 | self.show_dialogue = show 35 | 36 | ## initial prompt, write in 'prompt/task_info.json 37 | def initial_planning(self, decription, example): 38 | if self.llm_model is None: 39 | assert "no select Large Language Model" 40 | prompts = decription + example 41 | self.dialogue_system += decription + "\n" 42 | self.dialogue_system += example + "\n" 43 | 44 | ## set system part 45 | server_error_cnt = 0 46 | while server_error_cnt<10: 47 | try: 48 | url = self.llm_url 49 | headers = {'Content-Type': 'application/json'} 50 | 51 | data = {'model': self.llm_model, "messages":[{"role": "system", "content": prompts}]} 52 | response = requests.post(url, headers=headers, json=data) 53 | 54 | if response.status_code == 200: 55 | result = response.json() 56 | server_flag = 1 57 | 58 | 59 | if server_flag: 60 | break 61 | 62 | except Exception as e: 63 | server_error_cnt += 1 64 | print(e) 65 | 66 | def query_codex(self, prompt_text): 67 | server_flag = 0 68 | server_error_cnt = 0 69 | response = '' 70 | while server_error_cnt<10: 71 | try: 72 | #response = openai.Completion.create(prompt_text) 73 | url = self.llm_url 74 | headers = {'Content-Type': 'application/json'} 75 | 76 | # prompt_text 77 | 78 | data = {'model': self.llm_model, "messages":[{"role": "user", "content": prompt_text }]} 79 | response = requests.post(url, headers=headers, json=data) 80 | 81 | 82 | 83 | if response.status_code == 200: 84 | result = response.json() 85 | server_flag = 1 86 | 87 | 88 | if server_flag: 89 | break 90 | 91 | except Exception as e: 92 | server_error_cnt += 1 93 | print(e) 94 | if result is None: 95 | return 96 | else: 97 | return result['messages'][-1][-1] 98 | 99 | def check_plan_isValid(self, plan): 100 | if "{" in plan and "}" in plan: 101 | return True 102 | else: 103 | return False 104 | 105 | def step_planning(self, text): 106 | ## seed for LLM and get feedback 107 | plan = self.query_codex(text) 108 | if plan is not None: 109 | ## check Valid, llm may give wrong answer 110 | while not self.check_plan_isValid(plan): 111 | print("%s is illegal Plan! 
Replan ...\n" %plan) 112 | plan = self.query_codex(text) 113 | return plan 114 | 115 | @abstractmethod 116 | def forward(self): 117 | pass 118 | 119 | class SimpleDoorKey_Planner(Base_Planner): 120 | def __init__(self, seed=0): 121 | super().__init__() 122 | 123 | self.mediator = SimpleDoorKey_Mediator() 124 | if seed %2 ==0: 125 | self.llm_model = "vicuna-7b-3" 126 | self.llm_url = 'http://10.106.27.11:8003/v1/chat/completions' 127 | else: 128 | self.llm_model = "vicuna-7b-0" 129 | self.llm_url = 'http://10.106.27.11:8000/v1/chat/completions' 130 | 131 | def __call__(self, input): 132 | return self.forward(input) 133 | 134 | def reset(self, show=False): 135 | self.dialogue_user = '' 136 | self.dialogue_logger = '' 137 | self.show_dialogue = show 138 | ## reset dialogue 139 | if self.show_dialogue: 140 | print(self.dialogue_system) 141 | 142 | 143 | def forward(self, obs): 144 | text = self.mediator.RL2LLM(obs) 145 | # print(text) 146 | plan = self.step_planning(text) 147 | 148 | self.dialogue_logger += text 149 | self.dialogue_logger += plan 150 | self.dialogue_user = text +"\n" 151 | self.dialogue_user += plan 152 | if self.show_dialogue: 153 | print(self.dialogue_user) 154 | skill = self.mediator.LLM2RL(plan) 155 | return skill 156 | 157 | 158 | class KeyInBox_Planner(Base_Planner): 159 | def __init__(self,seed=0): 160 | super().__init__() 161 | self.mediator = KeyInBox_Mediator() 162 | if seed %2 == 0: 163 | self.llm_model = "vicuna-7b-4" 164 | self.llm_url = 'http://10.109.116.3:8004/v1/chat/completions' 165 | else: 166 | self.llm_model = "vicuna-7b-1" 167 | self.llm_url = 'http://10.109.116.3:8001/v1/chat/completions' 168 | 169 | def __call__(self, input): 170 | return self.forward(input) 171 | 172 | def reset(self, show=False): 173 | self.dialogue_user = '' 174 | self.dialogue_logger = '' 175 | self.show_dialogue = show 176 | ## reset dialogue 177 | if self.show_dialogue: 178 | print(self.dialogue_system) 179 | 180 | 181 | def forward(self, obs): 182 | text = self.mediator.RL2LLM(obs) 183 | # print(text) 184 | plan = self.step_planning(text) 185 | 186 | self.dialogue_logger += text 187 | self.dialogue_logger += plan 188 | self.dialogue_user = text +"\n" 189 | self.dialogue_user += plan 190 | if self.show_dialogue: 191 | print(self.dialogue_user) 192 | skill = self.mediator.LLM2RL(plan) 193 | return skill 194 | 195 | 196 | class RandomBoxKey_Planner(Base_Planner): 197 | def __init__(self, seed=0): 198 | super().__init__() 199 | self.mediator = RandomBoxKey_Mediator() 200 | if seed %2 == 0: 201 | self.llm_model = "vicuna-7b-5" 202 | self.llm_url = 'http://10.109.116.3:8005/v1/chat/completions' 203 | else: 204 | self.llm_model = "vicuna-7b-2" 205 | self.llm_url = 'http://10.109.116.3:8002/v1/chat/completions' 206 | def __call__(self, input): 207 | return self.forward(input) 208 | 209 | def reset(self, show=False): 210 | self.dialogue_user = '' 211 | self.dialogue_logger = '' 212 | self.show_dialogue = show 213 | ## reset dialogue 214 | if self.show_dialogue: 215 | print(self.dialogue_system) 216 | 217 | def forward(self, obs): 218 | text = self.mediator.RL2LLM(obs) 219 | # print(text) 220 | plan = self.step_planning(text) 221 | 222 | self.dialogue_logger += text 223 | self.dialogue_logger += plan 224 | self.dialogue_user = text +"\n" 225 | self.dialogue_user += plan 226 | if self.show_dialogue: 227 | print(self.dialogue_user) 228 | skill = self.mediator.LLM2RL(plan) 229 | return skill 230 | 231 | class ColoredDoorKey_Planner(Base_Planner): 232 | def __init__(self,seed=0): 233 | 
super().__init__() 234 | self.mediator = ColoredDoorKey_Mediator() 235 | if seed %2 == 0: 236 | self.llm_model = "vicuna-7b-7" 237 | self.llm_url = 'http://10.109.116.3:5678/v1/chat/completions' 238 | else: 239 | self.llm_model = "vicuna-7b-6" 240 | self.llm_url = 'http://10.109.116.3:8006/v1/chat/completions' 241 | 242 | 243 | def __call__(self, input): 244 | return self.forward(input) 245 | 246 | def reset(self, show=False): 247 | self.dialogue_user = '' 248 | self.dialogue_logger = '' 249 | self.show_dialogue = show 250 | ## reset dialogue 251 | if self.show_dialogue: 252 | print(self.dialogue_system) 253 | 254 | 255 | def forward(self, obs): 256 | text = self.mediator.RL2LLM(obs) 257 | # print(text) 258 | plan = self.step_planning(text) 259 | 260 | self.dialogue_logger += text 261 | self.dialogue_logger += plan 262 | self.dialogue_user = text +"\n" 263 | self.dialogue_user += plan 264 | if self.show_dialogue: 265 | print(self.dialogue_user) 266 | skill = self.mediator.LLM2RL(plan) 267 | return skill 268 | 269 | 270 | def Planner(task,seed=0): 271 | if task.lower() == "simpledoorkey": 272 | planner = SimpleDoorKey_Planner(seed) 273 | elif task.lower() == "keyinbox": 274 | planner = KeyInBox_Planner(seed) 275 | elif task.lower() == "randomboxkey": 276 | planner = RandomBoxKey_Planner(seed) 277 | elif task.lower() == "coloreddoorkey": 278 | planner = ColoredDoorKey_Planner(seed) 279 | return planner 280 | -------------------------------------------------------------------------------- /prompt/task_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "simpledoorkey":{ 3 | "episode": 100, 4 | "level": "easy", 5 | "description": "an agent in a minigrid environment in reinfrocement learning, the task of the agent is toggle the door in the maze with key. please help agent to plan the next action given agent's current observations and statu: carry {object} or none. Availabel actions may includes: explore, go to {object}, pick up {object}, toggle {object}. the actions should be displayed in a list. Do not explain the reasoning. \n ", 6 | "example": "observation: {observed nothing}, action: {explore}. \n observation: {observed a door}, action: {explore}. \n observation: {observed a key, observed a door}, action: {go to the key, pick up the key, go to the door, toggle the door}. \n observation: {observed a door, carry key}, action: {go to the door, toggle the door}. \n observation: {observed a key}, action: {go to the key, pick up the key, explore}.", 7 | "configurations": "MiniGrid-SimpleDoorKey-Min5-Max10-View3" 8 | }, 9 | 10 | "keyinbox":{ 11 | "episode": 100, 12 | "level": "easy", 13 | "description": "an agent in a minigrid environment in reinfrocement learning, the task of the agent is to toggle the door in the maze. key is hidden is a box. Please help agent to plan the next actions given observation and statu: carry {object} or none. Availabel actions may includes: explore, go to {object}, pick up {object}, toggle {object}. the actions should be displayed in a list. Do not explain the reasoning. \n ", 14 | "example": "Example: \n observation: {observed a box}, action: {go to the box, toggle the box}. \n observation: {observed nothing}, action: {explore}. \n observation: {observed a door}, action: {explore}. \n observation: {observed a key}, action: {go to the key, pick up the key}. \n observation: {observed a box, observed a door}, action: {go to the box, toggle the box}. 
\n observation: {observed a door, carry key}, action: {go to the door, toggle the door}. \n observation: {observed a key, observed a door}, action: {go to the key, pick up the key, go to the door, toggle the door}.", 15 | "configurations": "MiniGrid-KeyInBox-Min5-Max10-View3" 16 | }, 17 | 18 | "randomboxkey":{ 19 | "episode": 100, 20 | "level": "hard", 21 | "description": "an agent in a minigrid environment in reinfrocement learning, the task of the agent is to toggle the door in the maze. key may be hidden in a box. Please help agent to plan the next actions given observation and statu: carry {object} or none. Availabel actions may includes: explore, go to {object}, pick up {object}, toggle {object}. the actions should be displayed in a list. Do not explain the reasoning. \n" , 22 | "example": "Example: \n observation: {observed a box, carry key}, action: {explore}. \n observation: {observed a box}, action: {go to the box, toggle the box}. \n observation: {observed a box, observed a door, carry key}, action: {go to the door, toggle the door}. \n observation: {obseved a key, observed a box, observed a door}, action: {go to the key, pick up the key, go to the door, toggle the door}. \n observation: {observed nothing}, action: {explore}. \n observation: {observed a door}, action: {explore}. \n observation: {observed a key}, action: {go to the key, pick up the key}. \n observation: {observed a box, observed a door}, action: {go to the box, toggle the box}. \n observation: {observed a door, carry key}, action: {go to the door, toggle the door}. \n observation: {observed a key, observed a door}, action: {go to the key, pick up the key, go to the door, toggle the door}. \n observation: {observed a key, observed a box}, action: {go to the key, pick up the key, explore}. " , 23 | "configurations": "MiniGrid-RandomBoxKey-Min5-Max10-View3" 24 | }, 25 | 26 | 27 | "coloreddoorkey":{ 28 | "episode": 100, 29 | "level": "hard", 30 | "description": "An agent in a Minigrid environment in reinfrocement learning, the task of the agent is to toggle the color door with same color key. 
Format answer as following way:\n\n" , 31 | "example": "Q: [observed key, observed key, observed door]\nA: [observed key, observed key, observed door][observed key, observed door]{go to key, pick up key}\n\nQ: [observed key, observed door, carry key]\nA: [observed key, observed door, carry key][observed key, observed door]{go to key, pick up key}\n\nQ: [observed key, observed door, carry key]\nA: [observed key, observed door, carry key][observed door, carry key]{go to door, toggle door}\n\nQ: [observed door]\nA:[observed door]{explore}\n\nQ: [observed key, observed key]\nA: [observed key]{go to key, pick up key}[observed key, carry key]{explore}\n\nQ: [observed door, carry key]\nA: [observed door, carry key]{go to door, toggle door}\n\nQ: [observed key, observed door]\nA: [observed key, observed door]{go to key, pick up key}[observed door, carry key]{go to door, toggle door}\n\nQ: [observed door, carry key]\nA:[observed door, carry key]{explore}\n\nQ: [observed key, observed door]\nA: [observed key, observed door]{go to key, pick up key}[observed door, carry key]{explore}\n\nQ: [observed nothing]\nA: [observed nothing]{explore}\n", 32 | "configurations": "MiniGrid-ColoredDoorKey-Min5-Max10-View3" 33 | } 34 | 35 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gymnasium==0.28.1 2 | minigrid==2.2.1 3 | numpy==1.24.3 4 | Requests==2.31.0 5 | torch==2.0.1 6 | torch_ac==1.4.0 7 | -------------------------------------------------------------------------------- /skill/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_skill import * 2 | from .explore import * 3 | from .goto_goal import * -------------------------------------------------------------------------------- /skill/base_skill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from minigrid.core.constants import DIR_TO_VEC 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | class BaseSkill(): 38 | def __init__(self): 39 | pass 40 | 41 | def unpack_obs(self, obs): 42 | agent_map = obs[:, :, 3] 43 | self.agent_pos = np.argwhere(agent_map != 4)[0] 44 | self.agent_dir = obs[self.agent_pos[0], self.agent_pos[1], 3] 45 | self.map = obs[:, :, 0] 46 | # print(self.agent_pos, self.agent_dir) 47 | def step(self, obs=None): 48 | raise NotImplementedError 49 | 50 | def done_check(self): 51 | return False 52 | 53 | class Pickup(BaseSkill): 54 | def __init__(self, init_obs): 55 | init_obs = init_obs[:,:,-4:] 56 | self.path_prefix = [] 57 | self.path_suffix = [] 58 | self.plan(init_obs) 59 | 60 | def plan(self, init_obs, max_tries=30): 61 | self.unpack_obs(init_obs) 62 | 63 | if self.map[self.agent_pos[0], self.agent_pos[1]] == 1: #not carrying 64 | self.path = [3] 65 | else: 66 | angle_list = [0, 2, 1, 3] 67 | angle_list.remove(0) 68 | goto_angle = None 69 | finish = False 70 | tries = 0 71 | while not finish: 72 | search_angle = 
angle_list.pop(0) 73 | _drop, _goto = self.can_drop(search_angle) 74 | tries += 1 75 | if _drop: 76 | self.update_path(search_angle) 77 | self.path = self.path2action(self.path_prefix) + [4] + self.path2action(self.path_suffix) + [3] 78 | finish = True 79 | else: 80 | # since there is only 1 door, there is at most 1 angle can go to but cannot drop 81 | if _goto: 82 | goto_angle = search_angle 83 | 84 | if len(angle_list) == 0: 85 | if goto_angle or tries < max_tries: 86 | self.update_path(goto_angle, forward=True) 87 | self.agent_dir = (self.agent_dir + goto_angle) % 4 88 | self.agent_pos = self.agent_pos + DIR_TO_VEC[self.agent_dir] 89 | angle_list = [0, 2, 1, 3] 90 | angle_list.remove(2) # not search backward 91 | goto_angle = None 92 | else: 93 | finish = True 94 | self.path = [] 95 | print("path not found!") 96 | 97 | def can_drop(self, angle, distance=1): 98 | target_dir = (self.agent_dir + angle) % 4 99 | target_pos = self.agent_pos + DIR_TO_VEC[target_dir] * distance 100 | target_obj = self.map[target_pos[0], target_pos[1]] 101 | if target_obj != 1: # not empty 102 | _drop, _goto = False, False 103 | else: 104 | _drop, _goto = True, True 105 | for i in range(4): 106 | nearby_pos = target_pos + DIR_TO_VEC[i] 107 | if self.map[nearby_pos[0], nearby_pos[1]] == 4: # near a door 108 | _drop = False 109 | return _drop, _goto 110 | 111 | def update_path(self, angle, forward=False): 112 | if forward: 113 | if angle == 2: 114 | self.path_prefix += [2, 'f'] 115 | self.path_suffix = [2, 'f'] + self.path_suffix 116 | elif angle == 1: 117 | self.path_prefix += [1, 'f'] 118 | self.path_suffix = [2, 'f', 1] + self.path_suffix 119 | elif angle == 3: 120 | self.path_prefix += [3, 'f'] 121 | self.path_suffix = [2, 'f', 3] + self.path_suffix 122 | else: 123 | self.path_prefix += ['f'] 124 | self.path_suffix = [2, 'f', 2] + self.path_suffix 125 | else: 126 | if angle == 2: 127 | self.path_prefix += [2] 128 | self.path_suffix = [2] + self.path_suffix 129 | elif angle == 1: 130 | self.path_prefix += [1] 131 | self.path_suffix = [3] + self.path_suffix 132 | elif angle == 3: 133 | self.path_prefix += [3] 134 | self.path_suffix = [1] + self.path_suffix 135 | else: 136 | pass 137 | 138 | def path2action(self, path): 139 | angle = 0 140 | action_list = [] 141 | path.append('f') 142 | for i in path: 143 | if i == 'f': 144 | angle = angle % 4 145 | if angle == 1: 146 | action_list.append(1) 147 | elif angle == 3: 148 | action_list.append(0) 149 | elif angle == 2: 150 | action_list.extend([0, 0]) 151 | else: 152 | pass 153 | angle = 0 154 | action_list.append(2) 155 | else: 156 | angle += i 157 | return action_list[:-1] 158 | 159 | def step(self, obs): 160 | action = self.path.pop(0) 161 | done = len(self.path) == 0 162 | return action, done 163 | 164 | 165 | class Toggle(BaseSkill): 166 | def __init__(self): 167 | pass 168 | 169 | def step(self, obs): 170 | return 5, True 171 | -------------------------------------------------------------------------------- /skill/explore.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .base_skill import BaseSkill 7 | 8 | from utils import global_param 9 | class Explore(BaseSkill): 10 | def __init__(self, init_obs, view_size, exp=None): 11 | ''' 12 | Inputs: 13 | init_obs: {'image': width x height x channel, 'position': [x,y], 'direction': , 'mission': ,} 14 | view_size: env.agent_view_size 15 | ''' 16 | 17 | 
self.view_size = view_size
18 |         self.scope = self.view_size // 2
19 |         self.reset_to_NW = exp.reset_to_NW if exp is not None else False
20 |         init_obs = init_obs[:,:,-4:]
21 |         self.path = exp.path if exp is not None else self.plan(init_obs)
22 |         self.unpack_obs(init_obs)
23 |         # self.reset_to_NW = global_param.get_value('reset_to_NW')
24 | 
25 |         # self.path = global_param.get_value('path')
26 |         #self.path = self.plan(init_obs)
27 |         #self.reset_to_NW = False
28 | 
29 |     def plan(self, init_obs):
30 |         self.unpack_obs(init_obs)
31 |         if self.agent_dir == 0: # turn until facing west
32 |             path = [0, 0]
33 |         elif self.agent_dir == 1:
34 |             path = [1]
35 |         elif self.agent_dir == 2:
36 |             path = []
37 |         else:
38 |             path = [0]
39 |         return path
40 | 
41 |     def get_fwd_obj(self):
42 |         if self.agent_dir == 0:
43 |             fwd_obj = self.map[self.agent_pos[0] + self.scope, self.agent_pos[1]]
44 |         elif self.agent_dir == 2:
45 |             fwd_obj = self.map[max(0, self.agent_pos[0] - self.scope), self.agent_pos[1]]
46 |         elif self.agent_dir == 1:
47 |             fwd_obj = self.map[self.agent_pos[0], self.agent_pos[1] + self.scope]
48 |         else:
49 |             fwd_obj = self.map[self.agent_pos[0], max(0, self.agent_pos[1] - self.scope)]
50 |         return fwd_obj
51 | 
52 |     def step(self, obs):
53 |         obs = obs[:,:,-4:]
54 |         self.unpack_obs(obs)
55 |         fwd_obj = self.get_fwd_obj()
56 | 
57 |         if self.reset_to_NW == False:
58 |             if len(self.path) > 0:
59 |                 action = self.path.pop(0)
60 | 
61 |             elif fwd_obj == 2 or fwd_obj == 4: # hitting a wall or a door
62 |                 action = 1 # west -> north -> east
63 |                 if self.agent_dir == 3: # facing north
64 |                     self.reset_to_NW = True
65 |             else:
66 |                 action = 2
67 |         else:
68 |             if fwd_obj == 2 or fwd_obj == 4: # hitting a wall
69 |                 if self.agent_dir == 0: # facing east
70 |                     self.path = [2] * self.view_size + [1] # go south for self.view_size steps, then turn west
71 |                     action = 1 # east -> south
72 |                 elif self.agent_dir == 2: # facing west
73 |                     self.path = [2] * self.view_size + [0] # go south for self.view_size steps, then turn east
74 |                     action = 0 # west -> south
75 |                 elif self.agent_dir == 1: # facing south
76 |                     if self.path == []:
77 |                         action = np.random.choice([0,1])
78 |                     else:
79 |                         action = self.path[-1]
80 |                         self.path = []
81 |                 else:
82 |                     action = 1
83 |                     self.path = []
84 |             elif len(self.path) > 0:
85 |                 action = self.path.pop(0)
86 |             else:
87 |                 action = 2
88 |         done = self.done_check()
89 |         if done:
90 |             global_param.set_value('explore_done', True)
91 | 
92 |         return action, done
93 | 
94 |     def done_check(self):
95 |         done = False
96 |         ## Locate the boundary cells (walls and doors); the room is rectangular
97 |         wall = np.argwhere(self.map == 2)
98 |         door = np.argwhere(self.map == 4)
99 |         boundary = np.concatenate((wall, door))
100 |         ## if exploration is complete, the boundary must form a closed loop in the grid
101 |         if FindLoop(boundary):
102 |             lower_left = boundary.min(axis=0)
103 |             upper_right = boundary.max(axis=0)
104 |             room = self.map[lower_left[0]:upper_right[0]+1,lower_left[1]:upper_right[1]+1]
105 |             done = np.count_nonzero(room == 0) == 0 # no unseen cells remain
106 |         return done
107 | 
108 | 
109 | # DFS check: do the boundary cells form a closed loop?
110 | def FindLoop(boundary):
111 |     if boundary.size == 0:
112 |         return False
113 |     max_x,max_y = boundary.max(axis=0)
114 |     boundary = boundary.tolist()
115 | 
116 |     visited = set()
117 |     def dfs(x,y, parent):
118 |         # return True if a loop is found
119 |         if (x,y) in visited:
120 |             return True
121 |         visited.add((x,y))
122 |         for direct in [(-1, 0), (0, -1), (1, 0), (0, 1)]:
123 |             new_x = x + direct[0]
124 |             new_y = y + direct[1]
125 |             if new_x < 0 or new_x > max_x or new_y < 0 or
new_y > max_y: 126 | continue 127 | if (new_x, new_y) != parent and [new_x,new_y] in boundary: 128 | if dfs(new_x, new_y, (x,y)): 129 | return True 130 | return False 131 | x0 = boundary[0][0] 132 | y0 = boundary[0][1] 133 | loop = dfs(x0, y0, None) 134 | return loop 135 | 136 | 137 | if __name__ == "__main__": 138 | import numpy as np 139 | visited = set() 140 | loop = False 141 | boundary = np.array([[0,0],[0,1],[0,2],[1,0],[1,2],[2,2],[2,1],[2,0]]) 142 | print(FindLoop(boundary)) 143 | 144 | 145 | -------------------------------------------------------------------------------- /skill/goto_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .base_skill import BaseSkill 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | 38 | def check_go_through(pos, maps): 39 | x, y = pos 40 | width, height, _ = maps.shape 41 | if x<0 or x>=width or y<0 or y>=height: 42 | return False 43 | return (maps[x, y, 0] in [1, 8] or (maps[x, y, 0] == 4 and maps[x, y, 2]==0) ) 44 | 45 | def get_neighbors(pos_and_dir, maps): 46 | x, y, direction = pos_and_dir 47 | next_dir_left = direction - 1 if direction > 0 else 3 48 | next_dir_right = direction + 1 if direction < 3 else 0 49 | neighbor_list = [(x,y,next_dir_left), (x,y,next_dir_right)] 50 | forward_x, forward_y = DIRECTION[direction] 51 | new_x,new_y = (x+forward_x, y+forward_y) 52 | 53 | if check_go_through((new_x,new_y), maps): 54 | neighbor_list.append((new_x, new_y, direction)) 55 | 56 | 57 | assert not len(neighbor_list)==0 58 | 59 | return neighbor_list 60 | 61 | 62 | class GoTo_Goal(BaseSkill): 63 | def __init__(self, init_obs, target_pos): 64 | ''' 65 | Inputs: 66 | init_obs: {'image': width x height x 4 , 'mission': ,} 67 | target_pos: (x,y) 68 | ''' 69 | obs = init_obs[:,:,-4:] ## why here init_obs is [:,:,0] 70 | self.unpack_obs(obs) 71 | self.obs = obs 72 | x, y = target_pos 73 | target_pos_and_dir = [(x-1, y, 0), (x, y-1, 1), (x+1, y, 2), (x, y+1, 3)] 74 | self.path = self.plan(target_pos_and_dir) 75 | # print(self.path[0], self.path[-1]) 76 | 77 | def plan(self, target_pos_and_dir): 78 | start_node = (self.agent_pos[0], self.agent_pos[1], self.agent_dir) 79 | 80 | open_list = set([start_node]) 81 | closed_list = set([]) 82 | 83 | g = {} 84 | g[start_node] = 0 85 | 86 | parents = {} 87 | parents[start_node] = start_node 88 | 89 | while len(open_list) > 0: 90 | n = None 91 | 92 | for v in open_list: 93 | if n is None or g[v] < g[n]: 94 | n = v 95 | 96 | if n == None: 97 | print('No path found!!') 98 | return None 99 | 100 | ### reconstruct and return the path when the node is the goal position 101 | if n in target_pos_and_dir: 102 | reconst_path = [] 103 | while parents[n] != n: 104 | reconst_path.append(n) 105 | n = parents[n] 106 | 107 | reconst_path.append(start_node) 108 | reconst_path.reverse() 109 | return reconst_path 110 | 111 | for m in get_neighbors(n, self.obs): 112 | if m not in open_list and m not in closed_list: 113 | open_list.add(m) 114 | parents[m] = n 115 | g[m] = g[n] + 1 116 | 117 | else: 118 | if g[m] > 
g[n]+1: 119 | g[m] = g[n]+1 120 | parents[m] = n 121 | 122 | if m in closed_list: 123 | closed_list.remove(m) 124 | open_list.add(m) 125 | 126 | open_list.remove(n) 127 | closed_list.add(n) 128 | 129 | print('No path found!!!') 130 | print(self.obs[:,:,0].T) 131 | print(self.obs[:,:,3].T) 132 | print(target_pos_and_dir) 133 | return [] 134 | 135 | def step(self, obs=None): 136 | if obs is not None: 137 | obs = obs[:,:,-4:] 138 | self.unpack_obs(obs) 139 | cur_pos, cur_dir = self.agent_pos, self.agent_dir 140 | cur_agent = self.path.pop(0) 141 | assert cur_pos[0] == cur_agent[0] 142 | assert cur_pos[1] == cur_agent[1] 143 | assert cur_dir == cur_agent[2] 144 | else: 145 | cur_dir = self.path.pop(0)[2] 146 | 147 | next_dir = self.path[0][2] 148 | angle = (cur_dir - next_dir) % 4 149 | if angle == 1: 150 | action = 0 151 | elif angle == 3: 152 | action = 1 153 | elif angle == 0: 154 | action = 2 155 | else: 156 | print('No action find error!!!') 157 | action = None 158 | 159 | done = self.done_check() 160 | return action, done 161 | 162 | def done_check(self): 163 | return len(self.path)<=1 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import * 2 | from .eval import * 3 | from .format import * 4 | from .log import * 5 | from .logo import * 6 | from .global_param import * -------------------------------------------------------------------------------- /utils/env.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | from collections import deque 4 | from gymnasium import spaces 5 | 6 | def make_env_fn(env_key, render_mode=None, frame_stack=1): 7 | def _f(): 8 | env = gym.make(env_key, render_mode=render_mode) 9 | if frame_stack > 1: 10 | env = FrameStack(env, frame_stack) 11 | return env 12 | return _f 13 | 14 | 15 | # Gives a vectorized interface to a single environment 16 | class WrapEnv: 17 | def __init__(self, env_fn): 18 | self.env = env_fn() 19 | 20 | def __getattr__(self, attr): 21 | return getattr(self.env, attr) 22 | 23 | def step(self, action): 24 | if action.ndim == 1: 25 | env_return = self.env.step(action) 26 | else: 27 | env_return = self.env.step(action[0]) 28 | if len(env_return) == 4: 29 | state, reward, done, info = env_return 30 | else: 31 | state, reward, done, _, info = env_return 32 | if isinstance(state, dict): 33 | state = state['image'] 34 | return np.array([state]), np.array([reward]), np.array([done]), np.array([info]) 35 | 36 | 37 | def render(self): 38 | self.env.render() 39 | 40 | def reset(self, seed=None): 41 | state, *_ = self.env.reset(seed=seed) 42 | if isinstance(state, tuple): 43 | ## gym state is tuple type 44 | return np.array([state[0]]) 45 | elif isinstance(state, dict): 46 | ## minigrid state is dict type 47 | return np.array([state['image']]) 48 | else: 49 | return np.array([state]) 50 | 51 | 52 | class FrameStack(gym.Wrapper): 53 | def __init__(self, env, k): 54 | """Stack k last frames. 55 | Returns lazy array, which is much more memory efficient. 
56 | See Also 57 | -------- 58 | baselines.common.atari_wrappers.LazyFrames 59 | """ 60 | gym.Wrapper.__init__(self, env) 61 | self.k = k 62 | self.frames = deque([], maxlen=k) 63 | shp = env.observation_space['image'].shape 64 | self.observation_space = spaces.Box( 65 | low=0, 66 | high=255, 67 | shape=(shp[:-1] + (shp[-1] * k,)), 68 | dtype=env.observation_space['image'].dtype) 69 | 70 | def reset(self, seed=None): 71 | ob = self.env.reset(seed=seed)[0] 72 | ob = ob['image'] 73 | for _ in range(self.k): 74 | self.frames.append(ob) 75 | return self._get_ob(), {} 76 | 77 | def step(self, action): 78 | ob, reward, terminated, truncated, info = self.env.step(action) 79 | # ob, reward, done, info = self.env.step(action) 80 | ob = ob['image'] 81 | self.frames.append(ob) 82 | return self._get_ob(), reward, terminated, truncated, info 83 | 84 | def _get_ob(self): 85 | assert len(self.frames) == self.k 86 | return np.concatenate(self.frames, axis=-1) -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : eval.py 5 | @Time : 2023/05/24 09:51:17 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | import env 11 | import numpy as np 12 | class Eval(): 13 | def __init__(self, args, policy=None): 14 | self.args = args 15 | self.policy = policy 16 | 17 | 18 | def eval_policy(self, test_num): 19 | 20 | print("env name: %s for %s" %(self.args.task, self.args.save_name)) 21 | game = env.Game(self.args,self.policy) 22 | game.reset() 23 | 24 | reward = [] 25 | game_reward = [] 26 | lens = [] 27 | interactions = [] 28 | fail = 0 29 | for i in range(test_num): 30 | ## eval_policy ## 31 | if game.seed == None: 32 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.eval(game.env_fn, seed = i, show_dialogue=game.show_dialogue) 33 | else: 34 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.eval(game.env_fn, seed = game.seed, show_dialogue=game.show_dialogue) 35 | print("task %s, reward %s, len %s, interaction %s, reward w/o comm penalty %s" %(i, eval_reward,eval_len, eval_interactions, eval_game_reward)) 36 | # if eval_reward <= 0: 37 | if eval_len == 100: 38 | fail += 1 39 | 40 | reward.append(eval_reward) 41 | game_reward.append(eval_game_reward) 42 | lens.append(eval_len) 43 | interactions.append(eval_interactions) 44 | 45 | print("Mean reward:", np.mean(reward)) 46 | print("Mean reward w/o comm penalty:", np.mean(game_reward)) 47 | print("Mean len:", np.mean(lens)) 48 | print("Mean interactions:", np.mean(interactions)) 49 | print("Planning success rate:", 1.- fail/test_num) 50 | 51 | def eval_RL_policy(self, test_num): 52 | 53 | print("env name: %s for %s" %(self.args.task, self.args.save_name)) 54 | game = env.Game_RL(self.args,self.policy) 55 | game.reset() 56 | 57 | reward = [] 58 | lens = [] 59 | 60 | fail = 0 61 | for i in range(test_num): 62 | ## eval_policy ## 63 | if game.seed == None: 64 | eval_reward, eval_len,_, _ = game.eval(game.env_fn, seed = i, show_dialogue=game.show_dialogue) 65 | else: 66 | eval_reward, eval_len,_,_ = game.eval(game.env_fn, seed = game.seed, show_dialogue=game.show_dialogue) 67 | print("task %s, reward %s, len %s" %(i, eval_reward,eval_len)) 68 | # if eval_reward <= 0: 69 | if eval_len == 100: 70 | fail += 1 71 | 72 | reward.append(eval_reward) 73 | lens.append(eval_len) 74 | 75 | print("Mean reward:", np.mean(reward)) 76 | 
print("Mean len:", np.mean(lens)) 77 | print("Planning success rate:", 1.- fail/test_num) 78 | 79 | def eval_multi_heads_policy(self, test_num): 80 | 81 | print("env name: %s for %s" %(self.args.task, self.args.save_name)) 82 | game = env.Game_multi_heads(self.args,self.policy) 83 | game.reset() 84 | 85 | reward = [] 86 | game_reward = [] 87 | lens = [] 88 | interactions = [] 89 | fail = 0 90 | for i in range(test_num): 91 | ## eval_policy ## 92 | if game.seed == None: 93 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.eval(game.env_fn, seed = i, show_dialogue=game.show_dialogue) 94 | else: 95 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.eval(game.env_fn, seed = game.seed, show_dialogue=game.show_dialogue) 96 | print("task %s, reward %s, len %s, interaction %s, reward w/o comm penalty %s" %(i, eval_reward,eval_len, eval_interactions, eval_game_reward)) 97 | # if eval_reward <= 0: 98 | if eval_len == 100: 99 | fail += 1 100 | 101 | reward.append(eval_reward) 102 | game_reward.append(eval_game_reward) 103 | lens.append(eval_len) 104 | interactions.append(eval_interactions) 105 | 106 | 107 | print("Mean reward:", np.mean(reward)) 108 | print("Mean reward w/o comm penalty:", np.mean(game_reward)) 109 | print("Mean len:", np.mean(lens)) 110 | print("Mean interactions:", np.mean(interactions)) 111 | print("Planning success rate:", 1.- fail/test_num) 112 | 113 | def eval_baseline(self, test_num): 114 | print("env name: %s for %s" %(self.args.task, "baseline")) 115 | game = env.Game(self.args) 116 | game.reset() 117 | reward = [] 118 | game_reward = [] 119 | lens = [] 120 | interactions = [] 121 | fail = 0 122 | for i in range(test_num): 123 | ## eval_policy ## 124 | if game.seed == None: 125 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.baseline_eval(game.env_fn, seed = i, show_dialogue=game.show_dialogue) 126 | else: 127 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.baseline_eval(game.env_fn, seed = game.seed, show_dialogue=game.show_dialogue) 128 | print("task %s, reward %s, len %s, interaction %s, reward w/o comm penalty %s" %(i, eval_reward,eval_len, eval_interactions, eval_game_reward)) 129 | # if eval_reward <= 0: 130 | if eval_len == 100: 131 | fail += 1 132 | 133 | reward.append(eval_reward) 134 | game_reward.append(eval_game_reward) 135 | lens.append(eval_len) 136 | interactions.append(eval_interactions) 137 | 138 | print("Mean reward:", np.mean(reward)) 139 | print("Mean reward w/o comm penalty:", np.mean(game_reward)) 140 | print("Mean len:", np.mean(lens)) 141 | print("Mean interactions:", np.mean(interactions)) 142 | print("Planning success rate:", 1.- fail/test_num) 143 | 144 | def eval_always_ask(self, test_num): 145 | print("env name: %s for %s" %(self.args.task, "always_ask")) 146 | game = env.Game(self.args) 147 | game.reset() 148 | reward = [] 149 | game_reward = [] 150 | lens = [] 151 | interactions = [] 152 | fail = 0 153 | for i in range(test_num): 154 | ## eval_policy ## 155 | if game.seed == None: 156 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.ask_eval(game.env_fn, seed = i, show_dialogue=game.show_dialogue) 157 | else: 158 | eval_reward, eval_len, eval_interactions, eval_game_reward = game.ask_eval(game.env_fn, seed = game.seed, show_dialogue=game.show_dialogue) 159 | print("task %s, reward %s, len %s, interaction %s, reward w/o comm penalty %s" %(i, eval_reward,eval_len, eval_interactions, eval_game_reward)) 160 | # if eval_reward <= 0: 161 | if eval_len == 
100: 162 | fail += 1 163 | 164 | reward.append(eval_reward) 165 | game_reward.append(eval_game_reward) 166 | lens.append(eval_len) 167 | interactions.append(eval_interactions) 168 | 169 | print("Mean reward:", np.mean(reward)) 170 | print("Mean reward w/o comm penalty:", np.mean(game_reward)) 171 | print("Mean len:", np.mean(lens)) 172 | print("Mean interactions:", np.mean(interactions)) 173 | print("Planning success rate:", 1.- fail/test_num) 174 | if __name__ == '__main__': 175 | pass 176 | 177 | 178 | -------------------------------------------------------------------------------- /utils/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy 4 | import re 5 | import torch 6 | import torch_ac 7 | import gymnasium as gym 8 | 9 | import utils 10 | 11 | 12 | def get_obss_preprocessor(obs_space): 13 | # Check if obs_space is an image space 14 | if isinstance(obs_space, gym.spaces.Box): 15 | obs_space = {"image": obs_space.shape} 16 | 17 | def preprocess_obss(obss, device=None): 18 | return torch_ac.DictList({ 19 | "image": preprocess_images(obss, device=device) 20 | }) 21 | 22 | # Check if it is a MiniGrid observation space 23 | elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys(): 24 | obs_space = {"image": obs_space.spaces["image"].shape, "text": 100} 25 | 26 | vocab = Vocabulary(obs_space["text"]) 27 | 28 | def preprocess_obss(obss, device=None): 29 | return torch_ac.DictList({ 30 | "image": preprocess_images([obs["image"] for obs in obss], device=device), 31 | "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device) 32 | }) 33 | 34 | preprocess_obss.vocab = vocab 35 | 36 | else: 37 | raise ValueError("Unknown observation space: " + str(obs_space)) 38 | 39 | return obs_space, preprocess_obss 40 | 41 | 42 | def preprocess_images(images, device=None): 43 | # Bug of Pytorch: very slow if not first converted to numpy array 44 | images = numpy.array(images) 45 | return torch.tensor(images, device=device, dtype=torch.float) 46 | 47 | 48 | def preprocess_texts(texts, vocab, device=None): 49 | var_indexed_texts = [] 50 | max_text_len = 0 51 | 52 | for text in texts: 53 | tokens = re.findall("([a-z]+)", text.lower()) 54 | var_indexed_text = numpy.array([vocab[token] for token in tokens]) 55 | var_indexed_texts.append(var_indexed_text) 56 | max_text_len = max(len(var_indexed_text), max_text_len) 57 | 58 | indexed_texts = numpy.zeros((len(texts), max_text_len)) 59 | 60 | for i, indexed_text in enumerate(var_indexed_texts): 61 | indexed_texts[i, :len(indexed_text)] = indexed_text 62 | 63 | return torch.tensor(indexed_texts, device=device, dtype=torch.long) 64 | 65 | 66 | class Vocabulary: 67 | """A mapping from tokens to ids with a capacity of `max_size` words. 
68 |     It can be saved in a `vocab.json` file."""
69 | 
70 |     def __init__(self, max_size):
71 |         self.max_size = max_size
72 |         self.vocab = {}
73 | 
74 |     def load_vocab(self, vocab):
75 |         self.vocab = vocab
76 | 
77 |     def __getitem__(self, token):
78 |         if token not in self.vocab:
79 |             if len(self.vocab) >= self.max_size:
80 |                 raise ValueError("Maximum vocabulary capacity reached")
81 |             self.vocab[token] = len(self.vocab) + 1
82 |         return self.vocab[token]
83 | 
--------------------------------------------------------------------------------
/utils/global_param.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @File    :   global_param.py
5 | @Time    :   2023/05/23 10:55:41
6 | @Author  :   Hu Bin
7 | @Version :   1.0
8 | @Desc    :   global parameter dictionary shared by skills
9 | '''
10 | 
11 | def init():
12 |     global _global_dict
13 |     _global_dict = {}
14 | 
15 | def set_value(key, value):
16 |     _global_dict[key] = value
17 | 
18 | def get_value(key, defValue=None):
19 |     try:
20 |         return _global_dict[key]
21 |     except KeyError:
22 |         return defValue
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import os, pickle
3 | import json
4 | 
5 | class color:
6 |     BOLD = '\033[1m\033[48m'
7 |     END = '\033[0m'
8 |     ORANGE = '\033[38;5;202m'
9 |     BLACK = '\033[38;5;240m'
10 | 
11 | 
12 | def create_logger(args):
13 |     from torch.utils.tensorboard import SummaryWriter
14 |     """Use hyperparams to set a directory to output diagnostic files."""
15 | 
16 |     arg_dict = args.__dict__
17 | 
18 |     assert "logdir" in arg_dict, \
19 |         "You must provide a 'logdir' key in your command line arguments."
20 | 21 | arg_dict = OrderedDict(sorted(arg_dict.items(), key=lambda t: t[0])) 22 | logdir = str(arg_dict.pop('logdir')) 23 | output_dir = os.path.join(logdir, args.policy, args.task, args.save_name) 24 | 25 | os.makedirs(output_dir, exist_ok=True) 26 | 27 | # Create a file with all the hyperparam settings in plaintext 28 | info_path = os.path.join(output_dir, "config.json") 29 | 30 | with open(info_path,'wt') as f: 31 | json.dump(arg_dict, f, indent=4) 32 | 33 | logger = SummaryWriter(output_dir, flush_secs=0.1) 34 | print("Logging to " + color.BOLD + color.ORANGE + str(output_dir) + color.END) 35 | 36 | logger.name = args.save_name 37 | logger.dir = output_dir 38 | return logger 39 | 40 | def parse_previous(args): 41 | if args.previous is not None: 42 | run_args = pickle.load(open(args.previous + "config.json", "rb")) 43 | args.recurrent = run_args.recurrent 44 | args.env_name = run_args.env_name 45 | args.command_profile = run_args.command_profile 46 | args.input_profile = run_args.input_profile 47 | args.learn_gains = run_args.learn_gains 48 | args.traj = run_args.traj 49 | args.no_delta = run_args.no_delta 50 | args.ik_baseline = run_args.ik_baseline 51 | if args.exchange_reward is not None: 52 | args.reward = args.exchange_reward 53 | args.run_name = run_args.run_name + "_NEW-" + args.reward 54 | else: 55 | args.reward = run_args.reward 56 | args.run_name = run_args.run_name + "--cont" 57 | return args 58 | -------------------------------------------------------------------------------- /utils/logo.py: -------------------------------------------------------------------------------- 1 | class color: 2 | BOLD = '\033[1m\033[48m' 3 | END = '\033[0m' 4 | ORANGE = '\033[38;5;202m' 5 | BLACK = '\033[38;5;240m' 6 | 7 | def print_logo(subtitle=""): 8 | print(color.BOLD, end="") 9 | print(color.ORANGE, end="") 10 | print(" #@@# ") 11 | print(" #@ @# ") 12 | print("#@@# #@@# #@@# #@@# #@ @# #@@@@@@@@@@@@# #@@# ") 13 | print("#@@# #@@# #@@@# #@@@# #@ @# #@@@@@@@@@@@@@@@# #@@# ") 14 | print("#@@# #@@# #@@#@# #@#@@# #@@@@@@@@@# #@@# @@# #@@# ") 15 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 16 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 17 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 18 | print("#@@# #@@# #@@ #@# #@# @@# #@@@@@@@@@@@@@@@@# #@@# ") 19 | print("#@@# #@@# #@@ #@##@# @@# #@@@@@@@@@@@@@@@# #@@# ") 20 | print("#@@# #@@# #@@ #@@# @@# #@@# @@# #@@# ") 21 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 22 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 23 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 24 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 25 | print("#@@@@@@@@@@@@@@@@@# #@@@@@@@@@@@@@@@@@# #@@ @@# #@@# @@# #@@@@@@@@@@@@@@@@@#") 26 | print("#@@@@@@@@@@@@@@@@@# #@@@@@@@@@@@@@@@@@# #@@ @@# #@@# @@# #@@@@@@@@@@@@@@@@@#") 27 | 28 | 29 | print(color.END) 30 | print(subtitle + "\n\n") 31 | --------------------------------------------------------------------------------
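
All of the skills above share one small interface: a skill is constructed from an initial observation, and every call to `step(obs)` returns a MiniGrid action index together with a `done` flag (see `skill/base_skill.py`, `skill/explore.py`, `skill/goto_goal.py`). The snippet below is a minimal, hypothetical driver loop meant only to make that contract concrete; it is not the repo's planner/mediator/executive logic. It assumes that importing the `env` package registers the custom task ids (the id string is copied from `prompt/task_info.json`), that observations are frame-stacked with `frame_stack=2` so the last four channels hold the current frame, and that `view_size=3` matches the `View3` task suffix.

```python
# A minimal sketch of the skill interface, not the repo's executive/mediator loop.
# Assumptions: importing `env` registers the custom MiniGrid task ids, frame_stack=2
# stacks two 4-channel frames, and view_size=3 matches the "View3" task suffix.
import numpy as np

import env  # assumed to register the custom MiniGrid environments on import
from utils import global_param
from utils.env import make_env_fn, WrapEnv
from skill import Explore

global_param.init()  # skills report flags such as 'explore_done' through this dict

venv = WrapEnv(make_env_fn("MiniGrid-ColoredDoorKey-Min5-Max10-View3", frame_stack=2))
obs = venv.reset(seed=0)[0]          # WrapEnv returns a batch of size 1

skill = Explore(obs, view_size=3)    # plans an initial sweep from the first observation
for _ in range(100):
    action, skill_done = skill.step(obs)
    obs, reward, env_done, info = venv.step(np.array([action]))
    obs = obs[0]
    if skill_done or env_done[0]:
        break
```

In the repository itself this role is presumably played by `main.py` together with the planner, mediator, and executive components rather than a hand-written loop; the sketch only illustrates the `(action, done)` protocol that those components rely on.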