├── tech ├── __init__.py ├── technode.py └── techtree.py ├── images └── deckard.png ├── .gitmodules ├── .gitignore ├── requirements.txt ├── download_model.sh ├── tasks ├── base │ ├── __init__.py │ ├── reward_wrapper.py │ ├── clip_wrapper.py │ ├── success_wrapper.py │ ├── vpt_wrapper.py │ ├── terminal_wrapper.py │ ├── clip_reward.py │ └── techtree_wrapper.py ├── task_specs.yaml ├── __init__.py └── minedojo │ ├── __init__.py │ └── wrappers.py ├── sb3_vpt ├── types.py ├── logging.py ├── policy.py ├── buffer.py └── algorithm.py ├── README.md ├── techtree.py ├── subtask.py └── main.py /tech/__init__.py: -------------------------------------------------------------------------------- 1 | TECH_ORDER = ["wooden", "stone", "iron", "golden", "diamond"] 2 | -------------------------------------------------------------------------------- /images/deckard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeckardAgent/deckard/HEAD/images/deckard.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "VPT"] 2 | path = VPT 3 | url = https://github.com/kolbytn/VPT-Policy.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | figures/ 2 | models/ 3 | *venv/ 4 | weights/ 5 | logs/ 6 | results*/ 7 | trajectories/ 8 | __pycache__ 9 | .vscode/ 10 | output.txt 11 | temp* -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard==2.10.1 2 | gym3==0.3.3 3 | torch==1.12.1 4 | sb3-contrib==1.6.2 5 | stable-baselines3==1.6.2 6 | git+https://github.com/MineDojo/MineCLIP # Installs both Mineclip and Minedojo -------------------------------------------------------------------------------- /download_model.sh: -------------------------------------------------------------------------------- 1 | mkdir -p models 2 | mkdir -p weights 3 | 4 | wget https://openaipublic.blob.core.windows.net/minecraft-rl/models/foundation-model-3x.model -O models/3x.model 5 | wget https://openaipublic.blob.core.windows.net/minecraft-rl/models/bc-house-3x.weights -O weights/bc-house-3x.weights 6 | -------------------------------------------------------------------------------- /tasks/base/__init__.py: -------------------------------------------------------------------------------- 1 | from tasks.base.clip_wrapper import ClipWrapper 2 | from tasks.base.clip_reward import ClipReward 3 | from tasks.base.techtree_wrapper import TechTreeWrapper 4 | from tasks.base.reward_wrapper import RewardWrapper 5 | from tasks.base.success_wrapper import SuccessWrapper 6 | from tasks.base.terminal_wrapper import TerminalWrapper 7 | from tasks.base.vpt_wrapper import VPTWrapper 8 | -------------------------------------------------------------------------------- /sb3_vpt/types.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple 2 | import torch as th 3 | 4 | 5 | class VPTStates(NamedTuple): 6 | mask: Tuple[th.Tensor, ...] 7 | keys: Tuple[th.Tensor, ...] 8 | values: Tuple[th.Tensor, ...] 
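# VPTStates appears to hold the VPT transformer's recurrent attention state, stacked
# per transformer block along dim 0. Based on the shape comments in sb3_vpt/policy.py,
# a plausible layout is:
#   mask:   n_blocks x batch x buffer           (boolean mask over cached timesteps; a
#                                                row of -1 encodes a missing/None mask)
#   keys:   n_blocks x batch x buffer x hidden  (cached attention keys)
#   values: n_blocks x batch x buffer x hidden  (cached attention values)
# VPTPolicy._vpt_states_to_sb3 and _sb3_states_to_vpt convert between this flat tuple
# layout and VPT's native per-block (mask, (keys, values)) state.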
9 | 10 | 11 | class VPTRolloutBufferSamples(NamedTuple): 12 | observations: th.Tensor 13 | actions: th.Tensor 14 | old_values: th.Tensor 15 | old_log_prob: th.Tensor 16 | advantages: th.Tensor 17 | returns: th.Tensor 18 | vpt_states: VPTStates 19 | task_id: th.Tensor 20 | episode_starts: th.Tensor 21 | mask: th.Tensor 22 | -------------------------------------------------------------------------------- /tasks/task_specs.yaml: -------------------------------------------------------------------------------- 1 | wooden_pickaxe: 2 | task_id: creative 3 | sim: minedojo 4 | fast_reset: 0 5 | clip_specs: 6 | prompts: [] 7 | reward_specs: 8 | item_rewards: {} 9 | success_specs: 10 | terminal: false 11 | reward: 0 12 | all: {} 13 | terminal_specs: 14 | max_steps: 10000 15 | techtree_specs: 16 | guide_path: data/codex_techtree.json 17 | target_item: wooden_pickaxe 18 | 19 | base_task: 20 | task_id: creative 21 | sim: minedojo 22 | fast_reset: 5 23 | clip_specs: 24 | prompts: [] 25 | reward_specs: 26 | item_rewards: {} 27 | success_specs: 28 | terminal: false 29 | reward: 0 30 | all: 31 | item: 32 | quantity: 1 33 | terminal_specs: 34 | max_steps: 1000 35 | -------------------------------------------------------------------------------- /tasks/base/reward_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from gym import Wrapper 3 | from abc import ABC, abstractstaticmethod 4 | 5 | 6 | class RewardWrapper(Wrapper, ABC): 7 | def __init__(self, env, item_rewards: Dict[str, Dict[str, int]] = dict()): 8 | super().__init__(env) 9 | self.item_rewards = item_rewards 10 | self.last_inventory = None 11 | self.reward_count = None 12 | 13 | def reset(self): 14 | self.last_inventory = {item: 0 for item in self.item_rewards} 15 | self.reward_count = {item: 0 for item in self.item_rewards} 16 | return super().reset() 17 | 18 | def step(self, action): 19 | obs, reward, done, info = super().step(action) 20 | 21 | for item in self.item_rewards: 22 | curr_inv = self._get_item_count(obs, item) 23 | item_diff = curr_inv - self.last_inventory[item] 24 | if "quantity" in self.item_rewards[item]: 25 | item_diff = min(item_diff, self.item_rewards[item]["quantity"] - self.reward_count[item]) 26 | item_reward = self.item_rewards[item]["reward"] if "reward" in self.item_rewards[item] else 1 27 | reward += item_diff * item_reward 28 | self.reward_count[item] += item_diff 29 | self.last_inventory[item] = curr_inv 30 | 31 | return obs, reward, done, info 32 | 33 | @abstractstaticmethod 34 | def _get_item_count(obs, item): 35 | raise NotImplementedError() 36 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from tasks.minedojo import make_minedojo 3 | 4 | 5 | CUTOM_TASK_SPECS = OmegaConf.to_container(OmegaConf.load("tasks/task_specs.yaml")) 6 | 7 | 8 | def get_specs(task, **kwargs): 9 | # Get task data and task id 10 | if task in CUTOM_TASK_SPECS: 11 | yaml_specs = CUTOM_TASK_SPECS[task].copy() 12 | task_id = yaml_specs.pop("task_id", task) 13 | assert "sim" in yaml_specs, "task_specs.yaml must define sim attribute" 14 | else: 15 | yaml_specs = dict() 16 | task_id = task 17 | 18 | if "target_item" in kwargs and task == "base_task": 19 | yaml_specs["clip_specs"]["prompts"].append("collect " + kwargs["target_item"]) 20 | yaml_specs["reward_specs"]["item_rewards"][kwargs["target_item"]] = 
dict(reward=1) 21 | yaml_specs["success_specs"]["all"]["item"]["type"] = kwargs["target_item"] 22 | 23 | # Get minedojo specs 24 | sim_specs = yaml_specs.pop("sim_specs", dict()) 25 | 26 | # Get our task specs 27 | task_specs = dict( 28 | clip=False, 29 | fake_clip=False, 30 | fake_dreamer=False, 31 | subgoals=False, 32 | ) 33 | task_specs.update(**yaml_specs) 34 | task_specs.update(**kwargs) 35 | assert not (task_specs["clip"] and task_specs["fake_clip"]), "Can only use one reward shaper" 36 | 37 | return task_id, task_specs, sim_specs 38 | 39 | 40 | def make(task: str, **kwargs): 41 | # Get our custom task specs 42 | task_id, task_specs, sim_specs = get_specs(task, **kwargs) # Note: additional kwargs end up in task_specs dict 43 | 44 | # Make minedojo env 45 | env = make_minedojo(task_id, task_specs, sim_specs) 46 | 47 | return env 48 | -------------------------------------------------------------------------------- /tasks/base/clip_wrapper.py: -------------------------------------------------------------------------------- 1 | from gym import Wrapper 2 | import torch as th 3 | 4 | 5 | class ClipWrapper(Wrapper): 6 | def __init__(self, env, clip, prompts=None, dense_reward=.01, clip_target=23, clip_min=21, smoothing=50, **kwargs): 7 | super().__init__(env) 8 | self.clip = clip 9 | 10 | assert prompts is not None 11 | self.prompt = prompts 12 | self.dense_reward = dense_reward 13 | self.smoothing = smoothing 14 | self.clip_target = th.tensor(clip_target) 15 | self.clip_min = th.tensor(clip_min) 16 | 17 | self.buffer = None 18 | self._clip_state = None, None 19 | self.last_score = 0 20 | 21 | def reset(self, **kwargs): 22 | self._clip_state = None, self._clip_state[1] 23 | self.buffer = None 24 | self.last_score = 0 25 | return self.env.reset(**kwargs) 26 | 27 | def step(self, action): 28 | obs, reward, done, info = self.env.step(action) 29 | 30 | if len(self.prompt) > 0: 31 | logits, self._clip_state = self.clip.get_logits(obs, self.prompt, self._clip_state) 32 | logits = logits.detach().cpu() 33 | 34 | self.buffer = self._insert_buffer(self.buffer, logits[:1]) 35 | score = self._get_score() 36 | 37 | if score > self.last_score: 38 | reward += self.dense_reward * score 39 | self.last_score = score 40 | 41 | return obs, reward, done, info 42 | 43 | def _get_score(self): 44 | return (max( 45 | th.mean(self.buffer) - self.clip_min, 46 | 0 47 | ) / (self.clip_target - self.clip_min)).item() 48 | 49 | def _insert_buffer(self, buffer, logits): 50 | if buffer is None: 51 | buffer = logits.unsqueeze(0) 52 | elif buffer.shape[0] < self.smoothing: 53 | buffer = th.cat([buffer, logits.unsqueeze(0)], dim=0) 54 | else: 55 | buffer = th.cat([buffer[1:], logits.unsqueeze(0)], dim=0) 56 | return buffer 57 | -------------------------------------------------------------------------------- /tasks/base/success_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from gym import Wrapper 3 | from abc import ABC, abstractstaticmethod 4 | 5 | 6 | class SuccessWrapper(Wrapper, ABC): 7 | def __init__(self, env, terminal: bool = True, reward: int = 0, all: Dict = dict(), any: Dict = dict()): 8 | super().__init__(env) 9 | self.terminal = terminal 10 | self.all_conditions = all 11 | self.any_conditions = any 12 | self.success_reward = reward 13 | 14 | def step(self, action): 15 | obs, reward, done, info = super().step(action) 16 | 17 | info["success"] = info.pop("success", False) 18 | 19 | if len(self.all_conditions) > 0: 20 | info["success"] = 
info["success"] or all( 21 | self._check_condition(condition_type, condition_info, obs) 22 | for condition_type, condition_info in self.all_conditions.items() 23 | ) 24 | 25 | if len(self.any_conditions) > 0: 26 | info["success"] = info["success"] or any( 27 | self._check_condition(condition_type, condition_info, obs) 28 | for condition_type, condition_info in self.any_conditions.items() 29 | ) 30 | 31 | if self.terminal: 32 | done = done or info["success"] 33 | if info["success"]: 34 | reward += self.success_reward 35 | 36 | return obs, reward, done, info 37 | 38 | def _check_condition(self, condition_type, condition_info, obs): 39 | if condition_type == "item": 40 | return self._check_item_condition(condition_info, obs) 41 | elif condition_type == "blocks": 42 | return self._check_blocks_condition(condition_info, obs) 43 | else: 44 | raise NotImplementedError("{} terminal condition not implemented".format(condition_type)) 45 | 46 | @abstractstaticmethod 47 | def _check_item_condition(condition_info, obs): 48 | raise NotImplementedError() 49 | 50 | @abstractstaticmethod 51 | def _check_blocks_condition(condition_info, obs): 52 | raise NotImplementedError() 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DECKARD Minecraft Agent 2 | 3 | ![image](images/deckard.png) 4 | 5 | The DECKARD Minecraft agent uses knowledge from large language models to assist the exploration for reinforcement learning agents. This repository includes our implementation of the DECKARD agent for collecting and crafting arbitrary items in Minecraft. For additional details about our approach please see our [website](https://deckardagent.github.io/) and paper, [Do Embodied Agents Dream of Electric Sheep?](https://arxiv.org/abs/2301.12050). 6 | 7 | ## Installation 8 | 9 | We use [Minedojo](https://github.com/MineDojo/MineDojo) for agent training and evaluation in Minecraft. Before installing python dependencies for Minedojo, you will need `openjdk-8-jdk` and `python>=3.9`. This [guide](https://docs.minedojo.org/sections/getting_started/install.html#prerequisites) contains additional installation details for the Minedojo simulator. 10 | 11 | If you didn't clone the VPT submodule yet, run: 12 | ``` 13 | git submodule update --init --recursive 14 | ``` 15 | 16 | Next, install python packages: 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | We finetune our agent from OpenAI's [VPT Minecraft agent](https://github.com/openai/Video-Pre-Training). Download their pretrained weights using our script: 22 | ``` 23 | bash download_model.sh 24 | ``` 25 | 26 | Finally, we use [MineClip](https://github.com/MineDojo/MineCLIP) for reward shaping. Download the weights [here](https://drive.google.com/file/d/1uaZM1ZLBz2dZWcn85rZmjP7LV6Sg5PZW/view?usp=sharing) and place them at `weights/mineclip_attn.pth`. 27 | 28 | ## Usage 29 | 30 | The default way to use DECKARD occasionally pauses exploration to train subtasks. 
Run this method using:
31 | ```
32 | python main.py
33 | ```
34 | 
35 | Alternatively, you can pretrain policies for subtasks by running:
36 | ```
37 | python subtask.py --config base_task --target_item log
38 | ```
39 | 
40 | Then, add the trained subtask checkpoint to your yaml config under `techtree_specs.tasks`:
41 | ```
42 | my_config:
43 |   task_id: creative
44 |   sim: minedojo
45 |   fast_reset: 0
46 |   terminal_specs:
47 |     max_steps: 10000
48 |   techtree_specs:
49 |     guide_path: data/codex_techtree.json
50 |     target_item: wooden_pickaxe
51 |     tasks:
52 |       log: log_checkpoint.zip
53 | ```
54 | 
55 | and run DECKARD to build the Minecraft technology tree using:
56 | ```
57 | python techtree.py --config my_config
58 | ```
59 | 
60 | Note that Minecraft requires `xvfb-run` to render on a virtual display when running on a headless machine.
61 | 
--------------------------------------------------------------------------------
/tasks/base/vpt_wrapper.py:
--------------------------------------------------------------------------------
1 | from gym import Wrapper
2 | import gym.spaces as spaces
3 | import numpy as np
4 | from abc import ABC, abstractmethod
5 | 
6 | from VPT.agent import AGENT_RESOLUTION, ACTION_TRANSFORMER_KWARGS, resize_image
7 | from VPT.lib.action_mapping import CameraHierarchicalMapping
8 | from VPT.lib.actions import ActionTransformer
9 | 
10 | 
11 | class VPTWrapper(Wrapper, ABC):
12 |     def __init__(self, env, render=False, freeze_equipped=False):
13 |         super().__init__(env)
14 | 
15 |         self.action_mapper = CameraHierarchicalMapping(n_camera_bins=11)
16 |         self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)
17 | 
18 |         self.observation_space = spaces.Box(0, 255, shape=AGENT_RESOLUTION+(3,))
19 |         self.action_space = spaces.MultiDiscrete([
20 |             space.eltype.n
21 |             for space in self.action_mapper.get_action_space_update().values()
22 |         ])
23 | 
24 |         self.do_render = render
25 |         self.freeze_equipped = freeze_equipped
26 | 
27 |     def reset(self):
28 |         obs = self.env.reset()
29 |         obs = self._process_obs(obs)
30 |         return obs
31 | 
32 |     def step(self, action):
33 |         env_action = self._process_action(action)
34 |         obs, reward, done, info = self.env.step(env_action)
35 |         if self.do_render:
36 |             self.env.render()
37 | 
38 |         if hasattr(self, "is_successful"):
39 |             info["success"] = info.pop("success", False) or self.is_successful
40 |         obs = self._process_obs(obs)
41 |         return obs, reward, done, info
42 | 
43 |     def _process_action(self, action):
44 |         action = {
45 |             "camera": np.expand_dims(action[0], axis=0),
46 |             "buttons": np.expand_dims(action[1], axis=0)
47 |         }
48 |         minerl_action = self.action_mapper.to_factored(action)
49 |         minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
50 |         if self.freeze_equipped:
51 |             for name in ["drop", "swap_slot", "pickItem", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4",
52 |                          "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9"]:
53 |                 minerl_action_transformed.pop(name, None)
54 |         return self._filter_actions(minerl_action_transformed)
55 | 
56 |     def _process_obs(self, obs):
57 |         return resize_image(self._get_curr_frame(obs), AGENT_RESOLUTION)
58 | 
59 |     @abstractmethod
60 |     def _filter_actions(self, actions):
61 |         raise NotImplementedError()
62 | 
63 |     @abstractmethod
64 |     def _get_curr_frame(self, obs):
65 |         raise NotImplementedError()
66 | 
--------------------------------------------------------------------------------
/tasks/base/terminal_wrapper.py:
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from gym import Wrapper 3 | import numpy as np 4 | from abc import ABC, abstractstaticmethod 5 | 6 | 7 | class TerminalWrapper(Wrapper, ABC): 8 | def __init__(self, env, max_steps: int = 500, on_death=True, all: Dict = dict(), any: Dict = dict(), stagger_max_steps=False): 9 | super().__init__(env) 10 | self.max_steps = max_steps 11 | self.on_death = on_death 12 | self.all_conditions = all 13 | self.any_conditions = any 14 | self.t = 0 15 | self.curr_max_steps = self.max_steps 16 | self.stagger_max_steps = stagger_max_steps 17 | 18 | def reset(self): 19 | self.t = 0 20 | if self.stagger_max_steps: 21 | self.curr_max_steps = np.random.randint((self.max_steps*3)//4, self.max_steps+1) 22 | else: 23 | self.curr_max_steps = self.max_steps 24 | return super().reset() 25 | 26 | def step(self, action): 27 | obs, reward, done, info = super().step(action) 28 | 29 | self.t += 1 30 | done = done or self.t >= self.curr_max_steps 31 | 32 | if self.on_death: 33 | done = done or self._check_condition("death", {}, obs) 34 | 35 | if len(self.all_conditions) > 0: 36 | done = done or all( 37 | self._check_condition(condition_type, condition_info, obs) 38 | for condition_type, condition_info in self.all_conditions.items() 39 | ) 40 | 41 | if len(self.any_conditions) > 0: 42 | done = done or any( 43 | self._check_condition(condition_type, condition_info, obs) 44 | for condition_type, condition_info in self.any_conditions.items() 45 | ) 46 | 47 | return obs, reward, done, info 48 | 49 | def _check_condition(self, condition_type, condition_info, obs): 50 | if condition_type == "item": 51 | return self._check_item_condition(condition_info, obs) 52 | elif condition_type == "blocks": 53 | return self._check_blocks_condition(condition_info, obs) 54 | elif condition_type == "death": 55 | return self._check_death_condition(condition_info, obs) 56 | else: 57 | raise NotImplementedError("{} terminal condition not implemented".format(condition_type)) 58 | 59 | @abstractstaticmethod 60 | def _check_item_condition(condition_info, obs): 61 | raise NotImplementedError() 62 | 63 | @abstractstaticmethod 64 | def _check_blocks_condition(condition_info, obs): 65 | raise NotImplementedError() 66 | 67 | @abstractstaticmethod 68 | def _check_death_condition(condition_info, obs): 69 | raise NotImplementedError() -------------------------------------------------------------------------------- /tasks/minedojo/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from omegaconf import OmegaConf 3 | from minedojo.tasks import MetaTaskBase, _meta_task_make, _parse_inventory_dict, ALL_TASKS_SPECS 4 | from minedojo.sim import MineDojoSim 5 | 6 | from tasks.minedojo.wrappers import * 7 | 8 | 9 | def _get_minedojo_specs(task_id, task_specs, sim_specs): 10 | if task_id in ALL_TASKS_SPECS: 11 | minedojo_specs = ALL_TASKS_SPECS[task_id] 12 | if OmegaConf.is_config(minedojo_specs): 13 | minedojo_specs = OmegaConf.to_container(minedojo_specs) 14 | minedojo_specs.pop("prompt", None) 15 | meta_task_cls = minedojo_specs.pop("__cls__") 16 | else: 17 | minedojo_specs = dict() 18 | meta_task_cls = task_id 19 | 20 | minedojo_specs.update(dict( 21 | image_size=(160, 256), 22 | fast_reset=False, 23 | event_level_control=False, 24 | use_voxel=False, 25 | use_lidar=False 26 | )) 27 | 28 | # If using blocks condition, activate voxels 29 | if ("terminal_specs" in task_specs and \ 
30 | ("all" in task_specs["terminal_specs"] and any(x == "blocks" for x in task_specs["terminal_specs"]["all"]) or \ 31 | "any" in task_specs["terminal_specs"] and any(x == "blocks" for x in task_specs["terminal_specs"]["any"]))) or \ 32 | ("success_specs" in task_specs and \ 33 | ("all" in task_specs["success_specs"] and any(x == "blocks" for x in task_specs["success_specs"]["all"]) or \ 34 | "any" in task_specs["success_specs"] and any(x == "blocks" for x in task_specs["success_specs"]["any"]))): 35 | minedojo_specs["use_voxel"] = True 36 | minedojo_specs["voxel_size"] = dict(xmin=-3, ymin=-1, zmin=-3, xmax=3, ymax=1, zmax=3) 37 | 38 | minedojo_specs.update(**sim_specs) 39 | 40 | if "initial_inventory" in minedojo_specs: 41 | minedojo_specs["initial_inventory"] = _parse_inventory_dict(minedojo_specs["initial_inventory"]) 42 | 43 | return meta_task_cls, minedojo_specs 44 | 45 | 46 | def _add_wrappers( 47 | env: MetaTaskBase, 48 | task_id: str, 49 | reward_specs: Dict = None, 50 | success_specs: Dict = None, 51 | terminal_specs: Dict = None, 52 | clip_specs: Dict = None, 53 | techtree_specs: Dict = None, 54 | fast_reset: int = None, 55 | log_dir: str = None, 56 | freeze_equipped: bool = False, 57 | **kwargs 58 | ): 59 | if reward_specs: 60 | env = MinedojoRewardWrapper(env, **reward_specs) 61 | if success_specs: 62 | env = MinedojoSuccessWrapper(env, **success_specs) 63 | if terminal_specs is None: 64 | terminal_specs = dict(max_steps=500, on_death=True) 65 | env = MinedojoTerminalWrapper(env, **terminal_specs) 66 | 67 | # Add reward shaping wrapper 68 | if clip_specs is not None: 69 | clip_reward = MinedojoClipReward() 70 | env = ClipWrapper(env, clip_reward, **clip_specs) 71 | 72 | if techtree_specs is not None: 73 | env = MinedojoTechTreeWrapper(env, log_dir=log_dir, **techtree_specs) 74 | 75 | # Add VPT wrapper 76 | env = MinedojoVPTWrapper(env, freeze_equipped=freeze_equipped) 77 | 78 | # If we don't care about start position, use fast reset to speed training and prevent memory leaks 79 | if fast_reset is not None: 80 | wrapped = env 81 | while hasattr(wrapped, "env"): 82 | if isinstance(wrapped.env, MineDojoSim): 83 | wrapped.env = MinedojoSemifastResetWrapper( 84 | wrapped.env, 85 | reset_freq=fast_reset, 86 | random_teleport_range=200 87 | ) 88 | break 89 | wrapped = wrapped.env 90 | 91 | return env 92 | 93 | 94 | def make_minedojo(task_id: str, task_specs, sim_specs): 95 | 96 | # Get minedojo specs 97 | meta_task_cls, minedojo_specs = _get_minedojo_specs(task_id, task_specs, sim_specs) 98 | 99 | # Make minedojo env 100 | env = _meta_task_make(meta_task_cls, **minedojo_specs) 101 | 102 | # Add our wrappers 103 | env = _add_wrappers(env, task_id, **task_specs) 104 | 105 | return env 106 | -------------------------------------------------------------------------------- /tasks/base/clip_reward.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict 2 | import torch as th 3 | from omegaconf import OmegaConf 4 | from mineclip import MineCLIP 5 | from abc import ABC, abstractstaticmethod 6 | 7 | 8 | class ClipReward(ABC): 9 | def __init__(self, ckpt="weights/mineclip_attn.pth", **kwargs) -> None: 10 | kwargs["arch"] = kwargs.pop("arch", "vit_base_p16_fz.v2.t2") 11 | kwargs["hidden_dim"] = kwargs.pop("hidden_dim", 512) 12 | kwargs["image_feature_dim"] = kwargs.pop("image_feature_dim", 512) 13 | kwargs["mlp_adapter_spec"] = kwargs.pop("mlp_adapter_spec", "v0-2.t0") 14 | kwargs["pool_type"] = kwargs.pop("pool_type", 
"attn.d2.nh8.glusw") 15 | kwargs["resolution"] = [160, 256] 16 | 17 | self.resolution = self.get_resolution() 18 | self.device = kwargs.pop("device", "cuda") 19 | self.model = None 20 | 21 | self._load_mineclip(ckpt, kwargs) 22 | 23 | @abstractstaticmethod 24 | def get_resolution(): 25 | raise NotImplementedError() 26 | 27 | @abstractstaticmethod 28 | def _get_curr_frame(obs): 29 | raise NotImplementedError() 30 | 31 | def _load_mineclip(self, ckpt, config): 32 | config = OmegaConf.create(config) 33 | self.model = MineCLIP(**config).to(self.device) 34 | self.model.load_ckpt(ckpt, strict=True) 35 | if self.resolution != (160, 256): # Not ideal, but we need to resize the relative position embedding 36 | self.model.clip_model.vision_model._resolution = th.tensor([160, 256]) # This isn't updated from when mineclip resized it 37 | self.model.clip_model.vision_model.resize_pos_embed(self.resolution) 38 | self.model.eval() 39 | 40 | def _get_reward_from_logits( 41 | self, 42 | logits: th.Tensor # P 43 | ) -> float: 44 | probs = th.softmax(logits, 0) 45 | return max(probs[0].item() - 1 / logits.shape[0], 0) 46 | 47 | def _get_image_feats( 48 | self, 49 | curr_frame: th.Tensor, 50 | past_frames: th.Tensor = None 51 | ) -> th.Tensor: 52 | while len(curr_frame.shape) < 5: 53 | curr_frame = curr_frame.unsqueeze(0) 54 | assert curr_frame.shape == (1, 1, 3) + self.resolution, "Found shape {}".format(curr_frame.shape) 55 | curr_frame_feats = self.model.forward_image_features(curr_frame.to(self.device)) # 1 x 1 x 512 56 | 57 | if past_frames is None: 58 | past_frames = th.zeros((15, curr_frame_feats.shape[-1])) 59 | past_frames = past_frames.to(self.device) 60 | 61 | while len(past_frames.shape) < 3: 62 | past_frames = past_frames.unsqueeze(0) 63 | assert past_frames.shape == (1, 15, curr_frame_feats.shape[-1]), "Found shape {}".format(past_frames.shape) 64 | 65 | return th.cat((past_frames, curr_frame_feats), dim=1) 66 | 67 | def _get_video_feats( 68 | self, 69 | image_feats: th.Tensor 70 | ) -> th.Tensor: 71 | return self.model.forward_video_features(image_feats.to(self.device)) # 1 x 512 72 | 73 | def _get_text_feats( 74 | self, 75 | prompts: str 76 | ) -> th.Tensor: 77 | text_feats = self.model.encode_text(prompts) # P x 512 78 | assert len(text_feats.shape) == 2 and text_feats.shape[0] == len(prompts), "Found shape {}".format(text_feats.shape) 79 | return text_feats 80 | 81 | def get_logits( 82 | self, 83 | obs: Dict, # 3 x 160 x 256 84 | prompts: List[str], 85 | state: Tuple[th.Tensor, th.Tensor] = None # history x 512 86 | ) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]: 87 | 88 | curr_frame = self._get_curr_frame(obs) 89 | past_frames, text_feats = state 90 | 91 | with th.no_grad(): 92 | if text_feats is None: 93 | text_feats = self._get_text_feats(prompts) 94 | 95 | image_feats = self._get_image_feats(curr_frame, past_frames) 96 | video_feats = self._get_video_feats(image_feats) 97 | logits = self.model.forward_reward_head(video_feats.to(self.device), text_tokens=text_feats.to(self.device))[0][0] # P 98 | 99 | return logits, (image_feats[0, 1:].cpu(), text_feats.cpu()) 100 | 101 | def get_reward( 102 | self, 103 | obs: Dict, 104 | prompt: str, 105 | neg_prompts: List[str], 106 | state: Tuple[th.Tensor, th.Tensor] = None # history x 512 107 | ) -> Tuple[float, Tuple[th.Tensor, th.Tensor]]: 108 | logits, state = self.get_logits( 109 | obs, 110 | [prompt] + neg_prompts, 111 | state 112 | ) 113 | reward = self._get_reward_from_logits(logits) 114 | 115 | return reward, state 116 | 117 | def 
get_rewards( 118 | self, 119 | obs: Dict, 120 | prompts: List[str], 121 | neg_prompts: List[str], 122 | state: Tuple[th.Tensor, th.Tensor] = None # history x 512 123 | ) -> Tuple[List[float], Tuple[th.Tensor, th.Tensor]]: 124 | logits, state = self.get_logits( 125 | obs, 126 | prompts + neg_prompts, 127 | state 128 | ) 129 | rewards = [] 130 | for i in range(len(prompts)): 131 | rewards.append(self._get_reward_from_logits(th.cat(( 132 | logits[i:i+1], 133 | logits[len(prompts):] 134 | )))) 135 | return rewards, state 136 | -------------------------------------------------------------------------------- /techtree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import argparse 4 | import os 5 | import sys 6 | from datetime import datetime 7 | import shutil 8 | from copy import deepcopy 9 | import torch as th 10 | from stable_baselines3.common.utils import obs_as_tensor 11 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 12 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 13 | from stable_baselines3.common.env_util import make_vec_env 14 | from stable_baselines3.common.save_util import load_from_zip_file 15 | 16 | from sb3_vpt.policy import VPTPolicy 17 | from tasks import make, get_specs 18 | 19 | 20 | def explore_techtree(env, policy, max_explore_steps=1e6, device="cuda", allow_sample_untrained=True): 21 | last_episode_starts = np.ones((env.num_envs,), dtype=bool) 22 | last_task_id = np.zeros((env.num_envs,), dtype=np.int16) 23 | last_vpt_states = policy.initial_state(env.num_envs) 24 | vpt_states = deepcopy(last_vpt_states) 25 | last_obs = env.reset() 26 | 27 | num_timesteps = 0 28 | while num_timesteps < max_explore_steps: 29 | 30 | with th.no_grad(): 31 | # Convert to pytorch tensor or to TensorDict 32 | obs_tensor = obs_as_tensor(last_obs, device) 33 | episode_starts = th.tensor(last_episode_starts).float().to(device) 34 | actions, _, _, vpt_states = policy.forward(obs_tensor, vpt_states, episode_starts, last_task_id) 35 | 36 | new_obs, _, dones, infos = env.step(actions.cpu().numpy()) 37 | num_timesteps = max(info["timestep"] for info in infos) 38 | 39 | curr_task_id = np.array([ 40 | info["subgoal"] if "subgoal" in info else 0 41 | for info in infos 42 | ], dtype=np.int16) 43 | 44 | if any(info["early_stop"] for info in infos): 45 | return "success" 46 | 47 | if not allow_sample_untrained and any("untrained" in info and info["untrained"] for info in infos): 48 | return [info["untrained"] for info in infos if "untrained" in info and info["untrained"]][0] 49 | 50 | last_obs = new_obs 51 | last_episode_starts = dones 52 | last_vpt_states = vpt_states 53 | last_task_id = curr_task_id 54 | 55 | return "done" 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--name", type=str, default="test", 61 | help="Name of the experiment, will be used to create a results directory.") 62 | parser.add_argument("--config", type=str, default="wooden_pickaxe", 63 | help="Name of task. 
Will check tasks/task_specs for specified name.") 64 | parser.add_argument("--model", type=str, default="models/3x.model", 65 | help="Path to file that stores model parameters for the policy.") 66 | parser.add_argument("--weights", type=str, default="weights/bc-house-3x.weights", 67 | help="Path to the file that stores initial model weights for the policy.") 68 | parser.add_argument("--load", type=str, default="", 69 | help="Path to a zip filed to load from, saved by a previous run.") 70 | parser.add_argument("--results_dir", type=str, default="./results", 71 | help="Path to results dir.") 72 | parser.add_argument("--steps", type=int, default=1000000, 73 | help="Total number of learner environement steps before learning stops.") 74 | parser.add_argument("--num_envs", type=int, default=4, 75 | help="Number of environment instances to run. Set to 0 to run 1 instance in the learner thread.") 76 | parser.add_argument("--cpu", action="store_true", 77 | help="Use cpus over gpus.") 78 | args = parser.parse_args() 79 | 80 | _, task_specs, _ = get_specs(args.config) 81 | 82 | log_dir = os.path.join(args.results_dir, args.name + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 83 | os.makedirs(log_dir) 84 | 85 | env = make_vec_env( 86 | lambda task="", kwargs=dict(): make(task, **kwargs), 87 | n_envs=max(1, args.num_envs), 88 | vec_env_cls=SubprocVecEnv if args.num_envs > 0 else DummyVecEnv, 89 | env_kwargs=dict( 90 | task=args.config, 91 | kwargs=dict( 92 | log_dir=log_dir 93 | ) 94 | ) 95 | ) 96 | 97 | if args.load: 98 | shutil.copyfile(args.load, os.path.join(log_dir, "techtree.json")) 99 | 100 | agent_parameters = pickle.load(open(args.model, "rb")) 101 | policy_kwargs = agent_parameters["model"]["args"]["net"]["args"] 102 | policy_kwargs["transformer_adapters"] = True 103 | policy_kwargs["n_adapters"] = len(task_specs["techtree_specs"]["tasks"]) if "tasks" in task_specs["techtree_specs"] \ 104 | else task_specs["techtree_specs"].pop("max_tasks", 16) 105 | pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"] 106 | 107 | device = "cpu" if args.cpu else "cuda" 108 | 109 | policy = VPTPolicy( 110 | env.observation_space, 111 | env.action_space, 112 | lambda x: 0, 113 | policy_kwargs=policy_kwargs, 114 | pi_head_kwargs=pi_head_kwargs, 115 | weights_path=args.weights 116 | ).to(device) 117 | if "tasks" in task_specs["techtree_specs"]: 118 | for task_id, task_weights in enumerate(task_specs["techtree_specs"]["tasks"].values()): 119 | if not task_weights: 120 | continue 121 | _, params, _ = load_from_zip_file(task_weights, device=device) 122 | for n, x in policy.model.named_modules(): 123 | if "img_process" not in n and n.split(".")[-1] == "adapter": 124 | x.task_adapters[task_id].load_state_dict( 125 | {".".join(k.split(".")[2:]): v for k, v in params["policy.model." 
+ n].items()}
126 |                     )
127 |     policy.requires_grad_(False)
128 |     policy.set_training_mode(False)
129 | 
130 |     explore_techtree(env, policy, args.steps, device)
131 | 
--------------------------------------------------------------------------------
/tech/technode.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple
2 | from copy import deepcopy
3 | import json
4 | 
5 | 
6 | class TechNode:
7 |     def __init__(self, name: str, collectable: bool = False, recipe: List[Tuple["TechNode", int]] = [],
8 |                  tool: "TechNode" = None, table: bool = False, furnace: bool = False, timestep: int = 0, iteration: int = 0):
9 |         self.name = name
10 |         self.collectable = collectable
11 |         self.recipe = recipe
12 |         self.tool = tool
13 |         self.table = table
14 |         self.furnace = furnace
15 |         self.timestep = timestep
16 |         self.iteration = iteration
17 |         assert self not in self.get_subnodes(), "Found cycle at node " + name
18 | 
19 |     def get_depth(self) -> int:
20 |         depths = [0]
21 |         depths += [x.get_depth() for x, _ in self.recipe]
22 |         if self.table:
23 |             depths.append(3)
24 |         if self.furnace:
25 |             depths.append(6)
26 |         if self.tool is not None:
27 |             depths.append(self.tool.get_depth())
28 |         return 1 + max(depths)
29 | 
30 |     def get_children(self) -> List["TechNode"]:
31 |         return [x for x, _ in self.recipe]
32 | 
33 |     def get_subnodes(self) -> List["TechNode"]:
34 |         nodes = set()
35 |         for child in self.get_children():
36 |             nodes.update([child] + child.get_subnodes())
37 |         return list(nodes)
38 | 
39 |     def get_requirements(self) -> List[str]:
40 |         requirements = set([x.name for x, _ in self.recipe])
41 |         if self.tool is not None:
42 |             requirements.add(self.tool.name)
43 |         if self.table:
44 |             requirements.add("crafting_table")
45 |         if self.furnace:
46 |             requirements.add("furnace")
47 |         return list(requirements)
48 | 
49 |     def update_children_info(self, nodes: List["TechNode"]) -> None:
50 |         new_recipe = []
51 |         for n, q in self.recipe:
52 |             new_node = [x for x in nodes if x.name == n.name]
53 |             if len(new_node) > 0:
54 |                 new_recipe.append((new_node[0], q))
55 |             else:
56 |                 new_recipe.append((n, q))
57 |         self.recipe = new_recipe
58 | 
59 |         if self.tool is not None and self.tool in nodes:
60 |             self.tool = [x for x in nodes if x == self.tool][0]
61 | 
62 |         for n, _ in self.recipe:
63 |             n.update_children_info(nodes)
64 | 
65 |     def get_ingredients(self, inventory: Dict[str, int] = None) -> Dict["TechNode", int]:
66 |         if inventory is None:
67 |             inventory = dict()
68 |         return self._get_ingredients(deepcopy(inventory))[0]
69 | 
70 |     def get_craft_order(self, inventory: Dict[str, int] = None) -> List["TechNode"]:
71 |         if inventory is None:
72 |             inventory = dict()
73 |         return self._get_craft_order(deepcopy(inventory))[0]
74 | 
75 |     def purge(self, name):
76 |         self.recipe = [(n, q) for n, q in self.recipe if n.name != name]
77 |         for n, _ in self.recipe:
78 |             n.purge(name)
79 |         if self.tool is not None and self.tool.name == name:
80 |             self.tool = None
81 |         elif self.tool is not None:
82 |             self.tool.purge(name)
83 | 
84 |     def to_json(self) -> dict:
85 |         return dict(
86 |             name=self.name,
87 |             collectable=self.collectable,
88 |             tool=self.tool.to_json() if self.tool is not None else "",
89 |             table=self.table,
90 |             furnace=self.furnace,
91 |             timestep=self.timestep,
92 |             iteration=self.iteration,
93 |             recipe=[[x[0].to_json(), x[1]] for x in self.recipe]
94 |         )
95 | 
96 |     @staticmethod
97 |     def from_json(info: dict) -> "TechNode":
98 |         return TechNode(
99 | info["name"], 100 | collectable=info["collectable"] if "collectable" in info else False, 101 | recipe=[(TechNode.from_json(x[0]), x[1]) for x in info["recipe"]] if "recipe" in info else [], 102 | tool=TechNode.from_json(info["tool"]) if "tool" in info and info["tool"] != "" else None, 103 | table=info["table"] if "table" in info else False, 104 | furnace=info["furnace"] if "furnace" in info else False, 105 | timestep=info["timestep"] if "timestep" in info else 0, 106 | iteration=info["iteration"] if "iteration" in info else 0 107 | ) 108 | 109 | def _get_ingredients(self, inventory: Dict[str, int]) -> Tuple[Dict["TechNode", int], Dict[str, int]]: 110 | ingredients = dict() 111 | if self.collectable: 112 | return {self: 1}, inventory 113 | for node, quantity in self.recipe: 114 | if node.name in inventory: 115 | gathered = min(inventory[node.name], quantity) 116 | quantity -= gathered 117 | inventory[node.name] -= gathered 118 | for _ in range(quantity): 119 | node_ingredients, inventory = node._get_ingredients(inventory) 120 | for n, q in node_ingredients.items(): 121 | if n not in ingredients: 122 | ingredients[n] = 0 123 | ingredients[n] += q 124 | return ingredients, inventory 125 | 126 | def _get_craft_order(self, inventory: Dict[str, int]) -> Tuple[List["TechNode"], Dict[str, int]]: 127 | to_craft = [] 128 | for node, quantity in self.recipe: 129 | if node.name in inventory: 130 | gathered = min(inventory[node.name], quantity) 131 | quantity -= gathered 132 | inventory[node.name] -= gathered 133 | for _ in range(quantity): 134 | next_crafts, inventory = node._get_craft_order(inventory) 135 | to_craft += next_crafts 136 | if not self.collectable: 137 | to_craft += [self] 138 | return to_craft, inventory 139 | 140 | def __eq__(self, __o: object) -> bool: 141 | return hasattr(__o, "name") and __o.name == self.name 142 | 143 | def __hash__(self): 144 | return hash(self.name) 145 | 146 | def __str__(self) -> str: 147 | return json.dumps(self.to_json(), indent=4) 148 | -------------------------------------------------------------------------------- /sb3_vpt/logging.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from csv import DictWriter 4 | from datetime import datetime 5 | import psutil 6 | import json 7 | 8 | from stable_baselines3.common.callbacks import BaseCallback 9 | 10 | 11 | class LoggingCallback(BaseCallback): 12 | def __init__(self, model, log_dir, **kwargs): 13 | super().__init__(**kwargs) 14 | self.log_dir = log_dir 15 | self.model = model 16 | self.save_freq = 100 17 | 18 | self.log_out = [] 19 | self.iteration = -1 20 | self.num_dones = 0 21 | self.total_steps = 0 22 | self.rollout_steps = 0 23 | self.cum_rewards = None 24 | self.successes = None 25 | self.success_count = 0 26 | self.rollout_start_time = None 27 | self.update_start_time = None 28 | self.iter_rewards = [] 29 | self.subgoals = [] 30 | self.clip_scores = None 31 | 32 | def _on_training_start(self) -> None: 33 | self.cum_rewards = np.zeros(self.training_env.num_envs) 34 | self.successes = np.zeros(self.training_env.num_envs).astype(bool) 35 | self.clip_scores = [[] for _ in range(self.training_env.num_envs)] 36 | 37 | def _on_step(self) -> bool: 38 | self.cum_rewards += self.locals["rewards"] 39 | eps_rewards = self.cum_rewards[np.where(self.locals["dones"])].tolist() 40 | for r in eps_rewards: 41 | self.logger.record_mean("custom/reward", r) 42 | self.iter_rewards += eps_rewards 43 | self.cum_rewards *= (1 - 
self.locals["dones"]) 44 | if "subgoal" in self.locals["infos"][0]: 45 | self.subgoals += [info["subgoal"] for i, info in enumerate(self.locals["infos"]) if self.locals["dones"][i]] 46 | 47 | success = np.array([x["success"] if "success" in x else False for x in self.locals["infos"]]) 48 | self.successes = np.bitwise_or(success, self.successes) 49 | self.success_count += np.sum(np.bitwise_and(self.successes, self.locals["dones"])) 50 | self.successes = np.where(self.locals["dones"], 0, self.successes) 51 | 52 | # Log exploration 53 | if self.clip_scores is not None and "clip_scores" in self.locals["infos"][0]: 54 | for i, info in enumerate(self.locals["infos"]): 55 | self.clip_scores[i].append(info["clip_scores"]) 56 | if self.locals["dones"][i]: 57 | with open(os.path.join(self.log_dir, "clip_scores.txt"), "a") as f: 58 | f.write(json.dumps(self.clip_scores[i]) + "\n") 59 | self.clip_scores[i] = [] 60 | 61 | self.num_dones += np.sum(self.locals["dones"]) 62 | self.rollout_steps += self.training_env.num_envs 63 | self.total_steps += self.training_env.num_envs 64 | if "craft_steps" in self.locals["infos"][0]: 65 | self.total_steps += sum(info["craft_steps"] for info in self.locals["infos"]) 66 | 67 | return True 68 | 69 | def _on_rollout_start(self): 70 | if self.update_start_time is not None: 71 | print("Finished updated in", datetime.now() - self.update_start_time) 72 | self.logger.record("custom/update_secs", (datetime.now() - self.update_start_time).total_seconds()) 73 | print() 74 | 75 | self.iteration += 1 76 | if self.iteration > 0: 77 | self.log_out.append(dict( 78 | timesteps=self.total_steps, 79 | rollout_secs=self.update_start_time - self.rollout_start_time, 80 | update_secs=datetime.now() - self.update_start_time, 81 | dones=self.num_dones, 82 | success=self.success_count/self.num_dones if self.num_dones > 0 else np.nan, 83 | reward=np.mean(self.iter_rewards) if self.num_dones > 0 else np.nan, 84 | max_reward=np.amax(self.iter_rewards) if self.num_dones > 0 else np.nan, 85 | subgoals=np.mean(self.subgoals) if len(self.subgoals) > 0 else np.nan, 86 | max_subgoals=np.amax(self.subgoals) if len(self.subgoals) > 0 else np.nan, 87 | memory=psutil.virtual_memory()[3]/1e9 88 | )) 89 | with open(os.path.join(self.log_dir, "stats.csv"), "w") as f: 90 | writer = DictWriter(f, fieldnames=list(self.log_out[0].keys())) 91 | writer.writeheader() 92 | writer.writerows(self.log_out) 93 | if self.iteration % self.save_freq == 0: 94 | self.model.save(os.path.join(self.log_dir, "checkpoints", "timestep_{}".format(self.num_timesteps))) 95 | 96 | self.num_dones = 0 97 | self.rollout_steps = 0 98 | self.rollout_start_time = datetime.now() 99 | self.iter_rewards = [] 100 | self.subgoals = [] 101 | self.success_count = 0 102 | print("Starting rollout") 103 | 104 | def _on_rollout_end(self): 105 | self.update_start_time = datetime.now() 106 | print("Finished rollout in", self.update_start_time - self.rollout_start_time) 107 | print("\tMax reward:", np.amax(self.iter_rewards) if self.num_dones > 0 else np.nan) 108 | print("\tLast rewards:", np.mean(self.iter_rewards)) 109 | print("\tNum dones:", self.num_dones) 110 | print("\tMemory:", psutil.virtual_memory()[3]/1e9) 111 | print("\tFPS:", self.rollout_steps / (self.update_start_time - self.rollout_start_time).total_seconds()) 112 | 113 | self.logger.record("custom/rollout_secs", (self.update_start_time - self.rollout_start_time).total_seconds()) 114 | self.logger.record("custom/completed_episodes", self.num_dones) 115 | self.logger.record("custom/FPS", 
self.rollout_steps / (self.update_start_time - self.rollout_start_time).total_seconds()) 116 | if len(self.iter_rewards) > 0: 117 | self.logger.record("custom/max_reward", np.amax(self.iter_rewards)) 118 | if len(self.subgoals) > 0: 119 | self.logger.record("custom/max_subgoals", np.amax(self.subgoals)) 120 | self.logger.record("custom/subgoals", np.mean(self.subgoals)) 121 | self.logger.record("custom/memory", psutil.virtual_memory()[3]/1e9) 122 | 123 | if self.num_dones > 0: 124 | print("\tSuccesses: {}/{}={}".format(self.success_count, self.num_dones, self.success_count/self.num_dones)) 125 | self.logger.record("custom/success", self.success_count/self.num_dones) 126 | 127 | print("Starting update") -------------------------------------------------------------------------------- /subtask.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import os 4 | from datetime import datetime 5 | import sys 6 | import shutil 7 | 8 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from stable_baselines3.common.env_util import make_vec_env 11 | 12 | from sb3_vpt.algorithm import VPTPPO 13 | from sb3_vpt.policy import VPTPolicy 14 | from sb3_vpt.logging import LoggingCallback 15 | from tasks import make, get_specs 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--name", type=str, default="test", 21 | help="Name of the experiment, will be used to create a results directory.") 22 | parser.add_argument("--config", type=str, default="base_task", 23 | help="Minedojo task to run. Should be a minedojo task_id or exist in tasks/task_specs.yaml") 24 | parser.add_argument("--target_item", type=str, default="log", 25 | help="Item to use if using base_task.") 26 | parser.add_argument("--model", type=str, default="models/3x.model", 27 | help="Path to file that stores model parameters for the policy.") 28 | parser.add_argument("--weights", type=str, default="weights/bc-house-3x.weights", 29 | help="Path to the file that stores initial model weights for the policy.") 30 | parser.add_argument("--load", type=str, default="", 31 | help="Path to a zip filed to load from, saved by a previous run.") 32 | parser.add_argument("--results_dir", type=str, default="./results", 33 | help="Path to results dir.") 34 | parser.add_argument("--steps", type=int, default=10000000, 35 | help="Total number of learner environement steps before learning stops.") 36 | parser.add_argument("--steps_per_iter", type=int, default=500, 37 | help="Number of steps per environment each iteration.") 38 | parser.add_argument("--batch_size", type=int, default=40, 39 | help="Batch size for learning.") 40 | parser.add_argument("--n_epochs", type=int, default=5, 41 | help="Number of PPO epochs every iteration.") 42 | parser.add_argument("--num_envs", type=int, default=4, 43 | help="Number of environment instances to run. 
Set to 0 to run 1 instance in the learner thread.") 44 | parser.add_argument("--lr", type=float, default=1e-4, 45 | help="Learning rate.") 46 | parser.add_argument("--gamma", type=float, default=.999, 47 | help="Discount factor.") 48 | parser.add_argument("--kl_coef", type=float, default=.1, 49 | help="Initial loss coefficient for VPT KL loss.") 50 | parser.add_argument("--kl_decay", type=float, default=.999, 51 | help="How much to decay KL coefficient each iteration.") 52 | parser.add_argument("--adapter_factor", type=float, default=16, 53 | help="What reduction factor to use for adapters.") 54 | parser.add_argument("--cpu", action="store_true", 55 | help="Use cpus over gpus.") 56 | parser.add_argument("--update_norms", action="store_true", 57 | help="Update the layer norms of the network.") 58 | parser.add_argument("--final_layer", action="store_true", 59 | help="Update the layer immediately before the heads.") 60 | parser.add_argument("--policy_head", action="store_true", 61 | help="Update the policy head.") 62 | parser.add_argument("--no_transformer_adapters", action="store_true", 63 | help="Trains adapters in the transformer.") 64 | parser.add_argument("--finetune_full", action="store_true", 65 | help="Finetune the entire network.") 66 | parser.add_argument("--finetune_transformer", action="store_true", 67 | help="Finetune the transformer and heads.") 68 | args = parser.parse_args() 69 | 70 | _, task_specs, _ = get_specs(args.config) 71 | vars(args).update(**task_specs) 72 | 73 | log_dir = os.path.join(args.results_dir, args.name + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 74 | os.makedirs(os.path.join(log_dir, "checkpoints")) 75 | 76 | # sys.stderr = open(os.path.join(log_dir, "error.txt"), "w") 77 | 78 | env = make_vec_env( 79 | lambda task="", kwargs=dict(): make(task, **kwargs), 80 | n_envs=max(1, args.num_envs), 81 | vec_env_cls=SubprocVecEnv if args.num_envs > 0 else DummyVecEnv, 82 | env_kwargs=dict( 83 | task=args.config, 84 | kwargs=dict( 85 | log_dir=log_dir, 86 | target_item=args.target_item 87 | ) 88 | ) 89 | ) 90 | 91 | if args.load: 92 | model = VPTPPO.load(args.load, env) 93 | prev_log_dir = "/".join(args.load.split("/")[:-2]) 94 | if "techtree_specs" in task_specs: 95 | shutil.copyfile(os.path.join(prev_log_dir, "techtree.json"), os.path.join(log_dir, "techtree.json")) 96 | else: 97 | 98 | agent_parameters = pickle.load(open(args.model, "rb")) 99 | policy_kwargs = agent_parameters["model"]["args"]["net"]["args"] 100 | policy_kwargs["transformer_adapters"] = not args.no_transformer_adapters 101 | policy_kwargs["adapter_factor"] = args.adapter_factor 102 | policy_kwargs["n_adapters"] = 1 103 | pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"] 104 | pi_head_kwargs["adapter_factor"] = args.adapter_factor 105 | pi_head_kwargs["n_adapters"] = 1 106 | 107 | model = VPTPPO( 108 | VPTPolicy, 109 | env, 110 | n_steps=args.steps_per_iter, 111 | batch_size=args.batch_size, 112 | n_epochs=args.n_epochs, 113 | device="cpu" if args.cpu else "cuda", 114 | policy_kwargs=dict( 115 | policy_kwargs=policy_kwargs, 116 | pi_head_kwargs=pi_head_kwargs, 117 | weights_path=args.weights, 118 | ), 119 | tensorboard_log=os.path.join(log_dir, "tb"), 120 | learning_rate=args.lr, 121 | gamma=args.gamma, 122 | vf_coef=1, 123 | kl_coef=args.kl_coef, 124 | kl_decay=args.kl_decay, 125 | n_tasks=1, 126 | ) 127 | model.learn( 128 | args.steps, 129 | callback=LoggingCallback(model, log_dir) 130 | ) 131 | 
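# Example invocation (assuming the defaults above; on a headless machine Minedojo
# needs a virtual display via xvfb-run):
#   xvfb-run python subtask.py --config base_task --target_item log
# Checkpoints are periodically written by LoggingCallback to
# <results_dir>/<name>_<timestamp>/checkpoints/, and a saved checkpoint zip can then
# be listed under techtree_specs.tasks in tasks/task_specs.yaml for techtree.py to load.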
-------------------------------------------------------------------------------- /tasks/minedojo/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from minedojo.sim.wrappers.fast_reset import FastResetWrapper 4 | from minedojo.sim.mc_meta.mc import ALL_ITEMS, ALL_PERSONAL_CRAFTING_ITEMS, ALL_CRAFTING_TABLE_ITEMS, ALL_SMELTING_ITEMS,\ 5 | CRAFTING_RECIPES_BY_OUTPUT, SMELTING_RECIPES_BY_OUTPUT 6 | 7 | from tasks.base import * 8 | 9 | 10 | def name_match(target_name, obs_name): 11 | return target_name.replace(" ", "_") == obs_name.replace(" ", "_") 12 | 13 | 14 | # Fast reset wrapper saves time but doesn't replace blocks 15 | # Occasionally doing a hard reset should prevent state shift 16 | class MinedojoSemifastResetWrapper(FastResetWrapper): 17 | 18 | def __init__(self, *args, reset_freq=100, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.reset_freq = reset_freq 21 | self.reset_count = 0 22 | 23 | def reset(self): 24 | if self.reset_count < self.reset_freq: 25 | self.reset_count += 1 26 | return super().reset() 27 | else: 28 | self.reset_count = 0 29 | return self.env.reset() 30 | 31 | 32 | class MinedojoClipReward(ClipReward): 33 | @staticmethod 34 | def _get_curr_frame(obs): 35 | curr_frame = obs["rgb"].copy() 36 | return th.from_numpy(curr_frame) 37 | 38 | @staticmethod 39 | def get_resolution(): 40 | return (160, 256) 41 | 42 | 43 | class MinedojoTechTreeWrapper(TechTreeWrapper): 44 | def _get_inventory(self, obs): 45 | inventory = { 46 | name: sum( 47 | int(obs["inventory"]["quantity"][i]) 48 | for i, n in enumerate(obs["inventory"]["name"]) if name_match(name, n) 49 | ) 50 | for name in self._get_all_items() 51 | } 52 | inventory.pop("dirt", None) 53 | inventory.pop("air", None) 54 | inventory["log"] += inventory.pop("log2", 0) 55 | return inventory 56 | 57 | def _get_all_items(self): 58 | return list(set(ALL_ITEMS) - set(["dirt", "log2", "air"])) 59 | 60 | def _get_craftables(self): 61 | return ALL_CRAFTING_TABLE_ITEMS + ALL_SMELTING_ITEMS 62 | 63 | def _get_noop_action(self): 64 | return { 65 | handler.to_string(): handler.space.no_op() 66 | for handler in self.unwrapped._sim_spec.actionables 67 | } 68 | 69 | def _get_craft_action(self, item, crafting_table=False, furnace=False, valid=None): 70 | if valid is None: 71 | valid = self._get_all_items() 72 | action = { 73 | handler.to_string(): handler.space.no_op() 74 | for handler in self.unwrapped._sim_spec.actionables 75 | } 76 | if (item in ALL_PERSONAL_CRAFTING_ITEMS or crafting_table and item in ALL_CRAFTING_TABLE_ITEMS) and \ 77 | all(n in valid and self.curr_inventory[n] >= q for n, q in CRAFTING_RECIPES_BY_OUTPUT[item][0]["ingredients"].items()) or \ 78 | furnace and item in ALL_SMELTING_ITEMS and \ 79 | all(n in valid and self.curr_inventory[n] >= q for n, q in SMELTING_RECIPES_BY_OUTPUT[item][0]["ingredients"].items()): 80 | action["craft"] = item 81 | action["craft_with_table"] = item 82 | action["smelt"] = item 83 | else: 84 | action = None 85 | return action 86 | 87 | def _equip_item(self, obs, action, item): 88 | if obs is None: 89 | return action 90 | hotbar_items = [x.replace(" ", "_") for x in obs["inventory"]["name"][:9].tolist()] 91 | if item in hotbar_items: 92 | hotbar_names = ["hotbar." 
+ str(x) for x in range(1, 10)] 93 | for handler in self.unwrapped._sim_spec.actionables: 94 | if handler.to_string() in ["drop", "swap_slot", "pickItem"] + hotbar_names: 95 | action[handler.to_string()] = handler.space.no_op() 96 | action[hotbar_names[hotbar_items.index(item)]] = np.array(1) 97 | return action 98 | 99 | 100 | class MinedojoRewardWrapper(RewardWrapper): 101 | @staticmethod 102 | def _get_item_count(obs, item): 103 | return sum(quantity for name, quantity in zip(obs["inventory"]["name"], obs["inventory"]["quantity"]) if name_match(item, name)) 104 | 105 | 106 | class MinedojoSuccessWrapper(SuccessWrapper): 107 | @staticmethod 108 | def _check_item_condition(condition_info, obs): 109 | return sum(quantity for name, quantity in zip(obs["inventory"]["name"], obs["inventory"]["quantity"]) 110 | if name_match(condition_info["type"], name)) >= condition_info["quantity"] 111 | 112 | @staticmethod 113 | def _check_blocks_condition(condition_info, obs): 114 | target = np.array(condition_info) 115 | voxels = obs["voxels"]["block_name"].transpose(1,0,2) 116 | for y in range(voxels.shape[0] - target.shape[0]): 117 | for x in range(voxels.shape[1] - target.shape[1]): 118 | for z in range(voxels.shape[2] - target.shape[2]): 119 | if np.all(voxels[y:y+target.shape[0], 120 | x:x+target.shape[1], 121 | z:z+target.shape[2]] == target): 122 | return True 123 | return False 124 | 125 | 126 | class MinedojoTerminalWrapper(TerminalWrapper): 127 | @staticmethod 128 | def _check_item_condition(condition_info, obs): 129 | return sum(quantity for name, quantity in zip(obs["inventory"]["name"], obs["inventory"]["quantity"]) 130 | if name_match(condition_info["type"], name)) >= condition_info["quantity"] 131 | 132 | @staticmethod 133 | def _check_blocks_condition(condition_info, obs): 134 | target = np.array(condition_info) 135 | voxels = obs["voxels"]["block_name"].transpose(1,0,2) 136 | for y in range(voxels.shape[0] - target.shape[0]): 137 | for x in range(voxels.shape[1] - target.shape[1]): 138 | for z in range(voxels.shape[2] - target.shape[2]): 139 | if np.all(voxels[y:y+target.shape[0], 140 | x:x+target.shape[1], 141 | z:z+target.shape[2]] == target): 142 | return True 143 | return False 144 | 145 | @staticmethod 146 | def _check_death_condition(condition_info, obs): 147 | return obs["life_stats"]["life"].item() == 0 148 | 149 | 150 | class MinedojoVPTWrapper(VPTWrapper): 151 | 152 | def _filter_actions(self, actions): 153 | filtered_actions = { 154 | handler.to_string(): actions[handler.to_string()] 155 | if handler.to_string() in actions else handler.space.no_op() 156 | for handler in self.unwrapped._sim_spec.actionables # This comes from MinedojoSim.SimSpec 157 | } # Filter malmo actions by what current minedojo task has enabled 158 | return filtered_actions 159 | 160 | def _get_curr_frame(self, obs): 161 | return np.transpose(obs["rgb"], (1, 2, 0)) 162 | -------------------------------------------------------------------------------- /sb3_vpt/policy.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Optional, Tuple, Type, Union 2 | from gym3.types import DictType 3 | import gym 4 | import torch as th 5 | from torch import nn 6 | from itertools import chain 7 | 8 | from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy 9 | 10 | from sb3_vpt.types import VPTStates 11 | from VPT.lib.policy import MinecraftAgentPolicy 12 | from VPT.lib.action_mapping import CameraHierarchicalMapping 13 | from 
VPT.lib.tree_util import tree_map 14 | 15 | 16 | class VPTPolicy(RecurrentActorCriticPolicy): 17 | def __init__( 18 | self, 19 | observation_space: gym.spaces.Space, 20 | action_space: gym.spaces.Space, 21 | lr_schedule: Callable[[float], float], 22 | net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, 23 | activation_fn: Type[nn.Module] = nn.Tanh, 24 | *args, 25 | **kwargs\ 26 | ): 27 | policy_kwargs = kwargs.pop("policy_kwargs", dict()) 28 | pi_head_kwargs = kwargs.pop("pi_head_kwargs", dict()) 29 | weights_path = kwargs.pop("weights_path", None) 30 | vpt_action_space = DictType(**CameraHierarchicalMapping(n_camera_bins=11).get_action_space_update()) 31 | 32 | super().__init__( 33 | observation_space, 34 | action_space, 35 | lr_schedule, 36 | net_arch, 37 | activation_fn, 38 | *args, 39 | **kwargs, 40 | ) 41 | 42 | self.model = MinecraftAgentPolicy( 43 | policy_kwargs=policy_kwargs, 44 | pi_head_kwargs=pi_head_kwargs, 45 | action_space=vpt_action_space 46 | ) 47 | self.exploration_model = MinecraftAgentPolicy( 48 | policy_kwargs=policy_kwargs, 49 | pi_head_kwargs=pi_head_kwargs, 50 | action_space=vpt_action_space 51 | ) 52 | if weights_path: 53 | self.model.load_state_dict(th.load(weights_path), strict=False) 54 | self.exploration_model.load_state_dict(th.load(weights_path), strict=False) 55 | 56 | self.exploration_model.requires_grad_(False) 57 | 58 | self.model.requires_grad_(False) 59 | self.params = {} 60 | 61 | self.model.value_head.reset_parameters() 62 | self.model.value_head.requires_grad_(True) 63 | self.params["model.value_head"] = self.model.value_head.parameters() 64 | 65 | for n, x in self.model.named_modules(): 66 | if "img_process" not in n and n.split(".")[-1] == "adapter": 67 | x.requires_grad_(True) 68 | self.params["model." 
+ n] = x.parameters() 69 | 70 | self.optimizer = self.optimizer_class( 71 | chain(*self.params.values()), 72 | lr=lr_schedule(1), 73 | **self.optimizer_kwargs 74 | ) 75 | 76 | def get_param_keys(self) -> List[str]: 77 | return list(self.params.keys()) 78 | 79 | @staticmethod 80 | def _vpt_states_to_sb3(states): 81 | st = ([], [], []) 82 | for block_st in states: 83 | if block_st[0] is None: 84 | st[0].append(th.full_like(block_st[1][0], -1)[:, :, 0]) 85 | else: 86 | assert block_st[0].shape[1] == 1 87 | st[0].append(block_st[0][:, 0]) 88 | st[1].append(block_st[1][0]) 89 | st[2].append(block_st[1][1]) 90 | st = tuple([ 91 | th.cat([blk.unsqueeze(0) for blk in state], dim=0) 92 | for state in st 93 | ]) 94 | return VPTStates(*st) 95 | 96 | @staticmethod 97 | def _sb3_states_to_vpt(states): 98 | return tuple([ 99 | ( 100 | None if th.all(states[0][i] == -1) else \ 101 | states[0][i].unsqueeze(1).bool() if len(states[0][i].shape) == 2 else \ 102 | states[0][i].bool(), 103 | (states[1][i], states[2][i]) 104 | ) 105 | for i in range(states[0].shape[0]) 106 | ]) 107 | 108 | def initial_state(self, batch_size): 109 | return self._vpt_states_to_sb3(self.model.initial_state(batch_size)) 110 | 111 | def forward(self, 112 | obs: th.Tensor, # batch x H x W x C 113 | in_states: VPTStates, # n_blocks x 1 x buffer, n_blocks x 1 x buffer x hidden 114 | episode_starts: th.Tensor, # batch 115 | task_id: th.Tensor, # batch 116 | deterministic: bool = False 117 | ) -> Tuple[Dict[str, th.Tensor], th.Tensor, th.Tensor, VPTStates, Dict[str, th.Tensor]]: 118 | 119 | # pd: dict: batch x 1 x 1 x 121, batch x 1 x 1 x 8641 120 | # vpred: batch x 1 x 1 121 | (pd, vpred, _), state_out = self.model( 122 | tree_map(lambda x: x.unsqueeze(1), {"img": obs}), 123 | episode_starts.unsqueeze(1).bool(), 124 | self._sb3_states_to_vpt(in_states), 125 | task_id 126 | ) 127 | 128 | ac = self.model.pi_head.sample(pd, deterministic=deterministic) # dict: batch x 1 x 1 129 | log_prob = self.model.pi_head.logprob(ac, pd)[:, 0] # batch 130 | vpred = vpred[:, 0, 0] # batch 131 | ac = th.cat([x[:, 0] for x in ac.values()], dim=1) # batch x 2 132 | 133 | return ac, vpred, log_prob, self._vpt_states_to_sb3(state_out) 134 | 135 | def predict_values( 136 | self, 137 | obs: th.Tensor, 138 | in_states: VPTStates, 139 | episode_starts: th.Tensor, 140 | task_id: th.Tensor 141 | ) -> th.Tensor: 142 | (_, vpred, _), _ = self.model( 143 | tree_map(lambda x: x.unsqueeze(1), {"img": obs}), 144 | episode_starts.unsqueeze(-1).bool(), 145 | self._sb3_states_to_vpt(in_states), 146 | task_id 147 | ) 148 | return vpred[:, 0, 0] # batch x 1 149 | 150 | def evaluate_actions( 151 | self, 152 | obs: th.Tensor, # n_seq * max_len x H x W x C 153 | actions: th.Tensor, # n_seq * max_len x 2 154 | in_states: VPTStates, # n_blocks x n_seq x buffer, n_blocks x n_seq x buffer x hidden 155 | episode_starts: th.Tensor, # n_seq * max_len 156 | task_id: th.Tensor, # n_seq * max_len 157 | ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: 158 | 159 | n_seq = in_states[0].shape[1] 160 | obs_sequence = obs.reshape((n_seq, -1) + obs.shape[-3:]) # n_seq x max_len x H x W x C 161 | max_len = obs_sequence.shape[1] 162 | starts_sequence = episode_starts.reshape((n_seq, max_len)) # n_seq x max_len 163 | seq_task = task_id.reshape((n_seq, max_len))[:, 0] # n_seq 164 | model_input = {"img": obs_sequence}, starts_sequence.bool(), self._sb3_states_to_vpt(in_states), seq_task 165 | 166 | # pd: dict: n_seq x max_len x 1 x 121, n_seq x max_len x 1 x 8641 167 | # vpred: n_seq x max_len x 1 168 | 
(pd, vpred, _), _ = self.model(*model_input) 169 | with th.no_grad(): 170 | (exploration_pd, _, _), _ = self.exploration_model(*model_input) 171 | 172 | actions_dict = { 173 | k: actions[:, i].reshape((n_seq, max_len, 1)) 174 | for i, k in enumerate(self.model.pi_head.keys()) 175 | } 176 | log_prob = self.model.pi_head.logprob(actions_dict, pd) # n_seq x max_len 177 | kl = self.model.get_kl_of_action_dists(pd, exploration_pd) # n_seq x max_len x 1 178 | 179 | return th.flatten(vpred), th.flatten(log_prob), th.flatten(kl) 180 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import os 4 | from gym import Wrapper 5 | from datetime import datetime 6 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 7 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 8 | from stable_baselines3.common.env_util import make_vec_env 9 | 10 | from sb3_vpt.algorithm import VPTPPO 11 | from sb3_vpt.policy import VPTPolicy 12 | from tasks import make, get_specs 13 | from techtree import explore_techtree 14 | from tasks.base import * 15 | from tasks.minedojo import MinedojoSemifastResetWrapper 16 | 17 | 18 | class SetterWrapper(Wrapper): 19 | 20 | def recursive_set_attr(self, cls, attr, value): 21 | env = self.env 22 | while isinstance(env, Wrapper): 23 | if isinstance(env, cls): 24 | setattr(env, attr, value) 25 | return 26 | env = env.env 27 | if isinstance(env, cls): 28 | setattr(env, attr, value) 29 | 30 | def set_item_task(self, item): 31 | self.recursive_set_attr(MinedojoSemifastResetWrapper, "reset_freq", 5) 32 | self.recursive_set_attr(ClipWrapper, "prompt", ["collect " + item]) 33 | self.recursive_set_attr(RewardWrapper, "item_rewards", {item: {"reward":1}}) 34 | self.recursive_set_attr(SuccessWrapper, "all_conditions", {"item": {"type": item, "quantity": 1}}) 35 | self.recursive_set_attr(TerminalWrapper, "max_steps", 1000) 36 | self.recursive_set_attr(TechTreeWrapper, "is_techtree_active", False) 37 | 38 | def set_techtree_task(self): 39 | self.recursive_set_attr(MinedojoSemifastResetWrapper, "reset_freq", 0) 40 | self.recursive_set_attr(ClipWrapper, "prompt", []) 41 | self.recursive_set_attr(RewardWrapper, "item_rewards", {}) 42 | self.recursive_set_attr(SuccessWrapper, "all_conditions", {}) 43 | self.recursive_set_attr(TerminalWrapper, "max_steps", 10000) 44 | self.recursive_set_attr(TechTreeWrapper, "is_techtree_active", True) 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--name", type=str, default="test", 51 | help="Name of the experiment, will be used to create a results directory.") 52 | parser.add_argument("--config", type=str, default="wooden_pickaxe", 53 | help="Name of task. 
Will check tasks/task_specs.yaml for specified name.") 54 | parser.add_argument("--model", type=str, default="models/3x.model", 55 | help="Path to file that stores model parameters for the policy.") 56 | parser.add_argument("--weights", type=str, default="weights/bc-house-3x.weights", 57 | help="Path to the file that stores initial model weights for the policy.") 58 | parser.add_argument("--results_dir", type=str, default="./results", 59 | help="Path to results dir.") 60 | parser.add_argument("--explore_steps", type=int, default=2000000, 61 | help="Number of environment steps each iteration until exploration early stops.") 62 | parser.add_argument("--steps_per_subtask", type=int, default=2000000, 63 | help="Number of environment steps to allow for each subtask.") 64 | parser.add_argument("--steps_per_iter", type=int, default=500, 65 | help="Number of steps per environment each iteration.") 66 | parser.add_argument("--batch_size", type=int, default=40, 67 | help="Batch size for learning.") 68 | parser.add_argument("--n_epochs", type=int, default=5, 69 | help="Number of PPO epochs every iteration.") 70 | parser.add_argument("--num_envs", type=int, default=4, 71 | help="Number of environment instances to run. Set to 0 to run 1 instance in the learner thread.") 72 | parser.add_argument("--lr", type=float, default=1e-4, 73 | help="Learning rate.") 74 | parser.add_argument("--gamma", type=float, default=.999, 75 | help="Discount factor.") 76 | parser.add_argument("--kl_coef", type=float, default=.1, 77 | help="Initial loss coefficient for VPT KL loss.") 78 | parser.add_argument("--kl_decay", type=float, default=.999, 79 | help="How much to decay KL coefficient each iteration.") 80 | parser.add_argument("--cpu", action="store_true", 81 | help="Use cpus over gpus.") 82 | args = parser.parse_args() 83 | 84 | _, task_specs, _ = get_specs(args.config) 85 | assert "tasks" not in task_specs["techtree_specs"], "Don't use this script with pretrained task adapters. Use techtree.py instead." 
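# Outline of this script (see the training loop at the bottom): each iteration switches the
# envs into tech-tree exploration (set_techtree_task), runs explore_techtree() until a new
# collectable subtask is discovered, retargets every env to that single item (set_item_task),
# and finetunes a fresh adapter set with VPTPPO before exploration resumes.
# Hypothetical illustration of the SetterWrapper mechanics (not executed here; "log" and the
# log_dir value are placeholders):
#     wrapped = SetterWrapper(make("wooden_pickaxe", log_dir="/tmp/logs"))
#     wrapped.set_item_task("log")    # point CLIP prompt, rewards, and success checks at "log"
#     wrapped.set_techtree_task()     # switch back to open-ended tech-tree exploration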
86 | 87 | # Prepare results dir 88 | log_dir = os.path.join(args.results_dir, args.name + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 89 | os.makedirs(log_dir) 90 | print("Created results directory", log_dir) 91 | 92 | # Get techtree exploration env 93 | env = make_vec_env( 94 | lambda task="", kwargs=dict(): SetterWrapper(make(task, **kwargs)), 95 | n_envs=max(1, args.num_envs), 96 | vec_env_cls=SubprocVecEnv if args.num_envs > 0 else DummyVecEnv, 97 | env_kwargs=dict( 98 | task=args.config, 99 | kwargs=dict( 100 | log_dir=log_dir 101 | ) 102 | ) 103 | ) 104 | 105 | # Prepare policy 106 | max_tasks = task_specs["techtree_specs"].pop("max_tasks", 16) 107 | agent_parameters = pickle.load(open(args.model, "rb")) 108 | policy_kwargs = agent_parameters["model"]["args"]["net"]["args"] 109 | policy_kwargs["transformer_adapters"] = True 110 | policy_kwargs["n_adapters"] = max_tasks 111 | pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"] 112 | 113 | device = "cpu" if args.cpu else "cuda" 114 | 115 | model = VPTPPO( 116 | VPTPolicy, 117 | env, 118 | n_steps=args.steps_per_iter, 119 | batch_size=args.batch_size, 120 | n_epochs=args.n_epochs, 121 | device=device, 122 | policy_kwargs=dict( 123 | policy_kwargs=policy_kwargs, 124 | pi_head_kwargs=pi_head_kwargs, 125 | weights_path=args.weights, 126 | ), 127 | tensorboard_log=os.path.join(log_dir, "tb"), 128 | learning_rate=args.lr, 129 | gamma=args.gamma, 130 | vf_coef=1, 131 | kl_coef=args.kl_coef, 132 | kl_decay=args.kl_decay, 133 | n_tasks=1, 134 | ) 135 | 136 | # Get tasks so far 137 | subtasks = [] 138 | print("Starting training...") 139 | 140 | # Start training 141 | while True: 142 | if len(subtasks) >= max_tasks: 143 | print("Ran out of adapters for subtasks...") 144 | break 145 | 146 | # Explore to discover new subtask nodes 147 | print("Exploring subtasks") 148 | model.env.env_method("set_techtree_task") 149 | next_task = explore_techtree(model.env, model.policy, allow_sample_untrained=False, max_explore_steps=args.explore_steps) 150 | 151 | # If a new subtask is discovered 152 | if next_task == "success": 153 | print(args.config, "task success") 154 | break 155 | elif next_task == "done": 156 | print(args.config, "task failed") 157 | break 158 | else: 159 | print("Found new subtask:", next_task) 160 | 161 | # Modify environment for training subtask 162 | model.env.env_method("set_item_task", next_task) 163 | 164 | # Finetune adapters with RL 165 | print("Beginning adapter finetuning for subtask", next_task, "with adapter set", len(subtasks)) 166 | model.reset() 167 | model.set_task_id(len(subtasks)) 168 | model.learn(args.steps_per_subtask) 169 | model.save(os.path.join(log_dir, "task{}_{}".format(len(subtasks), next_task))) 170 | print("Subtask finetuning finished, saving model to task{}_{}.zip".format(len(subtasks), next_task)) 171 | 172 | # Add new subtask 173 | subtasks.append(next_task) 174 | model.env.env_method("add_task", next_task) 175 | print("Finished adding new subtask", next_task) 176 | model.env.close() 177 | -------------------------------------------------------------------------------- /sb3_vpt/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Generator, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | from gym import spaces 7 | from stable_baselines3.common.vec_env import VecNormalize 8 | from stable_baselines3.common.buffers import RolloutBuffer 9 | from 
sb3_contrib.common.recurrent.buffers import create_sequencers 10 | 11 | from sb3_vpt.types import VPTStates, VPTRolloutBufferSamples 12 | 13 | 14 | class VPTBuffer(RolloutBuffer): 15 | """ 16 | Rollout buffer that also stores the VPT hidden states. 17 | 18 | :param buffer_size: Max number of element in the buffer 19 | :param observation_space: Observation space 20 | :param action_space: Action space 21 | :param hidden_state_shape: Shape of the buffer that will collect states 22 | (n_steps, num_blocks, n_envs, buffer_size, hidden_size) 23 | :param device: PyTorch device 24 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator 25 | Equivalent to classic advantage when set to 1. 26 | :param gamma: Discount factor 27 | :param n_envs: Number of parallel environments 28 | """ 29 | 30 | def __init__( 31 | self, 32 | buffer_size: int, 33 | observation_space: spaces.Space, 34 | action_space: spaces.Space, 35 | hidden_state_shape: Tuple[int, int, int, int, int], 36 | state_buffer_size: int = 128, 37 | state_buffer_idx: int = 3, 38 | device: Union[th.device, str] = "auto", 39 | gae_lambda: float = 1, 40 | gamma: float = 0.99, 41 | n_envs: int = 1, 42 | ): 43 | self.hidden_state_shape = hidden_state_shape 44 | self.state_buffer_size = state_buffer_size 45 | self.state_buffer_idx = state_buffer_idx 46 | self.task_id_shape = (buffer_size, n_envs) 47 | self.seq_start_indices, self.seq_end_indices = None, None 48 | super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) 49 | 50 | def reset(self): 51 | super().reset() 52 | self.hidden_states_masks = np.zeros(self.hidden_state_shape[:-1], dtype=np.float32) 53 | self.hidden_states_keys = np.zeros(self.hidden_state_shape, dtype=np.float32) 54 | self.hidden_states_values = np.zeros(self.hidden_state_shape, dtype=np.float32) 55 | self.task_id = np.zeros(self.task_id_shape, dtype=np.int8) 56 | 57 | def add(self, *args, vpt_states: VPTStates, task_id: th.Tensor, **kwargs) -> None: 58 | """ 59 | :param hidden_states 60 | """ 61 | slc = (slice(None),) * (self.state_buffer_idx - 1) + (-1,) 62 | self.hidden_states_masks[self.pos] = np.array(vpt_states[0][slc].cpu().numpy()) 63 | self.hidden_states_keys[self.pos] = np.array(vpt_states[1][slc].cpu().numpy()) 64 | self.hidden_states_values[self.pos] = np.array(vpt_states[2][slc].cpu().numpy()) 65 | self.task_id[self.pos] = task_id 66 | 67 | super().add(*args, **kwargs) 68 | 69 | def get(self, batch_size: Optional[int] = None) -> Generator[VPTRolloutBufferSamples, None, None]: 70 | assert self.full, "Rollout buffer must be full before sampling from it" 71 | 72 | # Prepare the data 73 | if not self.generator_ready: 74 | # hidden_state_shape = (self.n_steps, num_blocks, self.n_envs, hidden_size) 75 | # swap first to (self.n_steps, self.n_envs, num_blocks, hidden_size) 76 | for tensor in ["hidden_states_masks", "hidden_states_keys", "hidden_states_values"]: 77 | self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) 78 | 79 | # flatten but keep the sequence order 80 | # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) 81 | # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) 82 | for tensor in [ 83 | "observations", 84 | "actions", 85 | "values", 86 | "log_probs", 87 | "advantages", 88 | "returns", 89 | "hidden_states_masks", 90 | "hidden_states_keys", 91 | "hidden_states_values", 92 | "task_id", 93 | "episode_starts" 94 | ]: 95 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) 96 | self.generator_ready = True 97 | 98 | # Return everything, don't create minibatches 99 | if batch_size is None: 100 | batch_size = self.buffer_size * self.n_envs 101 | 102 | # Sampling strategy that allows any mini batch size but requires 103 | # more complexity and use of padding 104 | # Trick to shuffle a bit: keep the sequence order 105 | # but split the indices in two 106 | split_index = np.random.randint(self.buffer_size * self.n_envs) 107 | indices = np.arange(self.buffer_size * self.n_envs) 108 | indices = np.concatenate((indices[split_index:], indices[:split_index])) 109 | 110 | env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) 111 | # Flag first timestep as change of environment 112 | env_change[0, :] = 1.0 113 | env_change = self.swap_and_flatten(env_change) 114 | 115 | start_idx = 0 116 | while start_idx < self.buffer_size * self.n_envs: 117 | batch_inds = indices[start_idx : start_idx + batch_size] 118 | yield self._get_samples(batch_inds, env_change) 119 | start_idx += batch_size 120 | 121 | def _get_samples( 122 | self, 123 | batch_inds: np.ndarray, 124 | env_change: np.ndarray, 125 | env: Optional[VecNormalize] = None, 126 | ) -> VPTRolloutBufferSamples: 127 | # Retrieve sequence starts and utility function 128 | self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( 129 | self.episode_starts[batch_inds], env_change[batch_inds], self.device 130 | ) 131 | 132 | # Number of sequences 133 | n_seq = len(self.seq_start_indices) 134 | max_length = self.pad(self.actions[batch_inds]).shape[1] 135 | padded_batch_size = n_seq * max_length 136 | 137 | # We retrieve the hidden states that will allow proper initialization the at the beginning of each sequence 138 | eps_start_indices = np.logical_or(self.episode_starts, env_change).flatten() 139 | eps_start_indices[0] = True 140 | 141 | masks = [] 142 | keys = [] 143 | values = [] 144 | for seq_start in batch_inds[self.seq_start_indices]: 145 | eps_start = np.where(eps_start_indices[:seq_start])[0] 146 | eps_start = eps_start[-1] if len(eps_start) > 0 else 0 # If len==0, seq_start also equals 0 147 | eps_start = max(eps_start, seq_start + 1 - self.state_buffer_size) # Only need 128 sized buffer 148 | padding_size = self.state_buffer_size - (seq_start + 1 - eps_start) # May need some padding 149 | 150 | # 1, buffer, n_blocks, dim 151 | masks.append(np.expand_dims(np.concatenate(( 152 | np.zeros((padding_size, self.hidden_states_masks.shape[-1]), dtype=np.float32), 153 | self.hidden_states_masks[eps_start:seq_start+1] 154 | ), axis=0), axis=0)) 155 | keys.append(np.expand_dims(np.concatenate(( 156 | np.zeros((padding_size,) + self.hidden_states_keys.shape[-2:], dtype=np.float32), 157 | self.hidden_states_keys[eps_start:seq_start+1] 158 | ), axis=0), axis=0)) 159 | values.append(np.expand_dims(np.concatenate(( 160 | np.zeros((padding_size,) + self.hidden_states_values.shape[-2:], dtype=np.float32), 161 | self.hidden_states_values[eps_start:seq_start+1] 162 | ), axis=0), axis=0)) 163 | 164 | # (n_seq, buffer, n_blocks, dim) -> (n_blocks, n_seq, buffer, dim) 165 | masks = 
np.concatenate(masks, axis=0).transpose((2, 0, 1)) 166 | keys = np.concatenate(keys, axis=0).transpose((2, 0, 1, 3)) 167 | values = np.concatenate(values, axis=0).transpose((2, 0, 1, 3)) 168 | 169 | vpt_states = (self.to_torch(masks), self.to_torch(keys), self.to_torch(values)) 170 | 171 | return VPTRolloutBufferSamples( 172 | # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) 173 | observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape), 174 | actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), 175 | old_values=self.pad_and_flatten(self.values[batch_inds]), 176 | old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), 177 | advantages=self.pad_and_flatten(self.advantages[batch_inds]), 178 | returns=self.pad_and_flatten(self.returns[batch_inds]), 179 | vpt_states=VPTStates(*vpt_states), 180 | task_id=self.pad_and_flatten(self.task_id[batch_inds]), 181 | episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), 182 | mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), 183 | ) 184 | -------------------------------------------------------------------------------- /tasks/base/techtree_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import numpy as np 3 | from gym import Wrapper 4 | from abc import ABC, abstractmethod 5 | from minedojo.sim.mc_meta.mc import * 6 | 7 | from tech.technode import TechNode 8 | from tech.techtree import TechTree 9 | 10 | 11 | class TechTreeWrapper(Wrapper, ABC): 12 | def __init__(self, env, craft_ticks=3, max_task_steps=1000, max_goal_steps=5000, **kwargs): 13 | super().__init__(env) 14 | self.craft_ticks = craft_ticks 15 | self.max_task_steps = max_task_steps 16 | self.max_goal_steps = max_goal_steps 17 | self.techtree = TechTree(self._get_all_items(), **kwargs) 18 | 19 | self.craftables = self._get_craftables() 20 | self.curr_inventory = {x: 0 for x in self._get_all_items() + ["any"]} 21 | self.last_inventory = self.curr_inventory 22 | self.curr_goal = None 23 | self.curr_task = None 24 | self.curr_tool = None 25 | self.task_steps = 0 26 | self.goal_steps = 0 27 | self._last_obs = None 28 | self.is_techtree_active = True 29 | self._sample_goal() 30 | 31 | def reset(self, **kwargs): 32 | self.curr_inventory = {x: 0 for x in self._get_all_items() + ["any"]} 33 | self.last_inventory = self.curr_inventory 34 | self._last_obs = None 35 | self._sample_goal() 36 | return self.env.reset(**kwargs) 37 | 38 | def step(self, action): 39 | if not self.is_techtree_active: 40 | return self.env.step(action) 41 | 42 | action["drop"] = 0 # manually disable dropping and placing items 43 | if self.curr_task is not None and self.curr_task.tool is not None: 44 | action = self._equip_item(self._last_obs, action, self.curr_task.tool.name) 45 | 46 | obs, reward, done, info = self.env.step(action) 47 | self.techtree.tick() 48 | 49 | # Gather info 50 | self.goal_steps += 1 51 | self.task_steps += 1 52 | self.curr_inventory = self._get_techtree_inventory(obs) 53 | 54 | ######### Tech Tree Logic ######### 55 | 56 | # Reward gathering items for current task 57 | reward += self._get_reward() 58 | 59 | # Add new items we've collected to the graph 60 | self.techtree.update_collectables(self.last_inventory, self.curr_inventory) 61 | 62 | if self.curr_task is not None and self.curr_inventory[self.curr_task.name] > self.last_inventory[self.curr_task.name]: 63 
| # print("\tCollected {} ({})".format(self.curr_task.name, self.curr_inventory[self.curr_task.name])) 64 | info["untrained"] = self.curr_task.name if self.curr_task.name not in ["dirt", "any"] and \ 65 | self.curr_task.name not in self.techtree.tasks and \ 66 | self.techtree.get_node_by_name(self.curr_task.name) is not None else "" 67 | 68 | # If we're working on a tool and have the ingredients, craft it 69 | if self.curr_tool is not None: 70 | obs, reward, done, info = self._craft(self.curr_tool, obs, reward, done, info) 71 | if self.curr_inventory[self.curr_tool.name] > 0: 72 | self.curr_tool = None 73 | 74 | # If we have ingredients for the current goal, craft it 75 | if not self.curr_goal.collectable: 76 | obs, reward, done, info = self._craft(self.curr_goal, obs, reward, done, info) 77 | 78 | # If we successfully obtained the current goal, try crafting something new, then sample a new goal 79 | if self._check_goal_success(): 80 | if not self.techtree.has_target(): 81 | obs, reward, done, info = self._try_craft_new_item(obs, reward, done, info) 82 | self._sample_goal(completed=True, success=True) 83 | 84 | # If we're still looking for a collectable and we're not crafting a tool, try crafting new items to expand the graph 85 | elif self.curr_goal.collectable and self.curr_tool is None: 86 | if not self.techtree.has_target(): 87 | obs, reward, done, info = self._try_craft_new_item(obs, reward, done, info) 88 | # If we haven't found the collectable yet, sample a new goal 89 | if self.task_steps >= self.max_task_steps: 90 | self._sample_goal(completed=True) 91 | 92 | else: 93 | # If we've been unable to craft the current goal, sample a new goal 94 | if self.goal_steps >= self.max_goal_steps: 95 | obs, reward, done, info = self._try_craft_new_item(obs, reward, done, info) 96 | self._sample_goal(completed=True) 97 | # If we haven't found current task yet, try looking for something else for the current goal 98 | if self.task_steps >= self.max_task_steps: 99 | self.curr_task = None 100 | 101 | # Update what we're currently looking for, if it requires a tool, that takes priority 102 | self._update_task() 103 | 104 | ######### End Tech Tree Logic ######### 105 | 106 | # Info to pass to algorithm 107 | info["subgoal"] = self.techtree.get_task_id(self.curr_task) 108 | info["timestep"] = self.techtree.total_steps 109 | info["early_stop"] = self.techtree.get_node_by_name(self.techtree.target_item) is not None \ 110 | if self.techtree.target_item is not None else False 111 | 112 | self._last_obs = obs.copy() 113 | self.last_inventory = self.curr_inventory 114 | return obs, reward, done, info 115 | 116 | def add_task(self, new_task): 117 | self.techtree.tasks.append(new_task) 118 | 119 | def _sample_goal(self, completed=False, success=False): 120 | self.curr_goal = self.techtree.sample_goal(last_goal=self.curr_goal, completed=completed, success=success) 121 | self.curr_task = None 122 | self.goal_steps = 0 123 | self.task_steps = 0 124 | 125 | def _update_task(self): 126 | new_task, self.curr_tool = self.techtree.update_task( 127 | self.curr_goal, self.curr_task, self.curr_tool, self.curr_inventory 128 | ) 129 | if new_task != self.curr_task: 130 | self.task_steps = 0 131 | self.curr_task = new_task 132 | 133 | def _craft(self, node: TechNode, obs, reward, done, info): 134 | crafts = node.get_craft_order(self.curr_inventory) 135 | 136 | crafting = True 137 | while crafting: 138 | 139 | crafting = False 140 | for to_craft in crafts: 141 | 142 | if not done and self._check_ingredients(to_craft): 143 | 
action = self._get_craft_action( 144 | to_craft.name, 145 | self.curr_inventory["crafting_table"] > 0, 146 | self.curr_inventory["furnace"] > 0 147 | ) 148 | if action is not None: 149 | obs, reward, done, info = self.env.step(action) 150 | for _ in range(self.craft_ticks): 151 | if not done: 152 | obs, reward, done, info = self.env.step(self._get_noop_action()) 153 | self.techtree.tick(1 + self.craft_ticks) 154 | 155 | inventory = self._get_techtree_inventory(obs) 156 | if inventory[to_craft.name] > self.curr_inventory[to_craft.name]: 157 | # print("\tCrafted {} ({})".format(to_craft.name, inventory[to_craft.name])) 158 | if to_craft not in self.techtree.get_all_nodes(): 159 | self.techtree.add_craft(to_craft.name, self.curr_inventory, inventory) 160 | self.curr_inventory = inventory 161 | crafts = node.get_craft_order(self.curr_inventory) 162 | crafting = to_craft != node 163 | break 164 | 165 | else: 166 | self.curr_inventory = inventory 167 | 168 | return obs, reward, done, info 169 | 170 | def _try_craft_new_item(self, obs, reward, done, info): 171 | if self.techtree.check_seen_inv(self.curr_inventory): 172 | return obs, reward, done, info 173 | 174 | success = False 175 | crafted = [x.name for x in self.techtree.get_all_nodes() if not x.collectable] 176 | names = [x for x in self.craftables if x in self.curr_inventory and x not in crafted] 177 | np.random.shuffle(names) 178 | for name in names: 179 | if not done: 180 | action = self._get_craft_action( 181 | name, 182 | self.curr_inventory["crafting_table"] > 0, 183 | self.curr_inventory["furnace"] > 0, 184 | [n.name for n in self.techtree.get_all_nodes() if not n.collectable or n.name in self.techtree.tasks] 185 | ) 186 | if action is not None: 187 | obs, reward, done, info = self.env.step(action) 188 | for _ in range(self.craft_ticks): 189 | if not done: 190 | obs, reward, done, info = self.env.step(self._get_noop_action()) 191 | self.techtree.tick(1 + self.craft_ticks) 192 | inventory = self._get_techtree_inventory(obs) 193 | if inventory[name] > self.curr_inventory[name]: 194 | # print("\tCrafted New {} ({})".format(name, inventory[name])) 195 | self.techtree.add_craft(name, self.curr_inventory, inventory) 196 | success = True 197 | break 198 | if done: 199 | break 200 | 201 | if not success: 202 | self.techtree.add_seen_invs(self.curr_inventory) 203 | 204 | self.curr_inventory = self._get_techtree_inventory(obs) 205 | return obs, reward, done, info 206 | 207 | def _check_ingredients(self, node: TechNode) -> bool: 208 | return sum(x for x in node.get_ingredients(self.curr_inventory).values()) == 0 and \ 209 | (not node.table or self.curr_inventory["crafting_table"] > 0) and \ 210 | (not node.furnace or self.curr_inventory["furnace"] > 0) 211 | 212 | def _get_reward(self) -> float: 213 | if self.curr_task is not None and self.curr_task.name != "any": 214 | return int(self.curr_inventory[self.curr_task.name] > self.last_inventory[self.curr_task.name]) 215 | else: 216 | return 0 217 | 218 | def _check_goal_success(self, goal: TechNode = None) -> bool: 219 | if goal is None: 220 | goal = self.curr_goal 221 | return self.curr_inventory[goal.name] > self.last_inventory[goal.name] 222 | 223 | def _get_techtree_inventory(self, obs: Dict) -> Dict[str, int]: 224 | inventory = self._get_inventory(obs) 225 | prev_count = sum(self.last_inventory.values()) if self.last_inventory is not None else 0 226 | inventory["any"] = max(0, sum(inventory.values()) - prev_count) 227 | return inventory 228 | 229 | @abstractmethod 230 | def 
_get_inventory(obs: Dict) -> Dict[str, int]: 231 | # retrieves a dict of item names to quantities for every game item 232 | raise NotImplementedError() 233 | 234 | @abstractmethod 235 | def _get_all_items(self) -> List[str]: 236 | # retrieves a list of all item names 237 | raise NotImplementedError() 238 | 239 | @abstractmethod 240 | def _get_craftables(self) -> List[str]: 241 | # retrieves a list of craftable item names 242 | raise NotImplementedError() 243 | 244 | @abstractmethod 245 | def _get_noop_action(self) -> Dict: 246 | # the no op action 247 | raise NotImplementedError() 248 | 249 | @abstractmethod 250 | def _get_craft_action(self, item, crafting_table=False, furnace=False, valid=None) -> Dict: 251 | # retrieves the craft action for the given item with the given tools 252 | raise NotImplementedError() 253 | 254 | @abstractmethod 255 | def _equip_item(self, last_obs: Dict, action: Dict, item: str) -> Dict: 256 | # retrieves a dict of item names to game actions 257 | raise NotImplementedError() 258 | -------------------------------------------------------------------------------- /tech/techtree.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | import numpy as np 3 | import json 4 | import os 5 | from copy import deepcopy 6 | from filelock import FileLock 7 | 8 | from tech import TECH_ORDER 9 | from tech.technode import TechNode 10 | 11 | 12 | class TechTree: 13 | def __init__(self, all_items, log_dir=None, tasks=dict(), guide_path=None, target_item=None, init_graph=None, **kwargs): 14 | self.techtree_lock = FileLock(os.path.join(log_dir, "techtree.json.lock"), timeout=10) 15 | self.techtree_path = os.path.join(log_dir, "techtree.json") 16 | self.target_item = target_item 17 | self.steps_since_sync = 0 18 | self.total_steps = 0 19 | self.iterations = 0 20 | self.seen_invs = [] 21 | self.node_visits = {x: 0 for x in all_items + ["any"]} 22 | self.node_success = {x: 1 for x in all_items + ["any"]} 23 | 24 | self.guide_trees = [] 25 | self.trees = [] 26 | self.tasks = list(tasks.keys()) 27 | 28 | # Initialize trees if given 29 | if init_graph is not None: 30 | with self.techtree_lock: 31 | self._sync_info() 32 | self.trees = [TechNode.from_json(x) for x in init_graph] 33 | self._save_info() 34 | 35 | # Initialize with exploration task 36 | self._add_node(TechNode("any", collectable=True)) 37 | 38 | # If using a guide, retrieve predicted graph 39 | if guide_path is not None: 40 | with open(guide_path, "r") as f: 41 | self.guide_trees = [TechNode.from_json(x) for x in json.load(f)] 42 | 43 | # Remove unknown items from guidance 44 | unknown_items = [x.name for x in self.get_all_nodes(guide=True) if x.name not in all_items] 45 | self.guide_trees = [x for x in self.guide_trees if x.name not in unknown_items] 46 | for name in unknown_items: 47 | for root in self.guide_trees: 48 | root.purge(name) 49 | 50 | def tick(self, n: int = 1) -> None: 51 | self.steps_since_sync += n 52 | 53 | def get_task_id(self, task: TechNode) -> int: 54 | if task is not None and task.name in self.tasks and self.get_node_by_name(task.name) is not None: 55 | return self.tasks.index(task.name) 56 | return -1 57 | 58 | def has_target(self) -> bool: 59 | # Check if there is an unexplored boundary 60 | # Visit count can be 1 if agent is currently exploring it 61 | return self.target_item is not None and any(self.node_visits[n.name] <= 1 for n in self._get_boundary()) 62 | 63 | def get_all_nodes(self, guide: bool = False) -> List[TechNode]: 
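# Union of every root node and its descendants, taken from the learned tree or, when guide=True, from the guide tree loaded from guide_path.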
64 | trees = self.guide_trees if guide else self.trees 65 | all_nodes = set() 66 | for root in trees: 67 | all_nodes.update([root] + root.get_subnodes()) 68 | return list(all_nodes) 69 | 70 | def get_node_by_name(self, name: str, guide: bool = False) -> TechNode: 71 | matches = [x for x in self.get_all_nodes(guide=guide) if x.name == name] 72 | if len(matches) == 0: 73 | return None 74 | return matches[0] 75 | 76 | def update_collectables(self, prev_inv: Dict[str, int], curr_inv: Dict[str, int]) -> None: 77 | collected = set() 78 | for n, q in curr_inv.items(): 79 | if q > prev_inv[n]: 80 | collected.add(n) 81 | for name in collected: 82 | if self.get_node_by_name(name) is None: 83 | tool = None 84 | for n, q in prev_inv.items(): 85 | # For now we only care if a pickaxe is required 86 | if q > 0 and "pickaxe" in n and (tool is None or self._get_tech_rank(n) > self._get_tech_rank(tool.name)): 87 | tool = self.get_node_by_name(n) 88 | new_node = TechNode( 89 | name, 90 | collectable=True, 91 | tool=tool, 92 | timestep=self.total_steps+self.steps_since_sync, 93 | iteration=self.iterations 94 | ) 95 | self._add_node(new_node) 96 | 97 | def check_seen_inv(self, inv: Dict[str, int]) -> bool: 98 | return sum(q for q in inv.values()) == 0 or len(self.seen_invs) > 0 and \ 99 | any(all( 100 | q <= seen[n] if n in seen else False 101 | for n, q in inv.items() if q > 0 and n != "any" 102 | ) for seen in self.seen_invs) 103 | 104 | def add_craft(self, name, prev_inv: Dict[str, int], curr_inv: Dict[str, int]) -> None: 105 | recipe = [] 106 | for n, q in curr_inv.items(): 107 | if n != "any" and q < prev_inv[n]: 108 | node = self.get_node_by_name(n) 109 | assert node is not None, "Crafted {} with item ({}) not in dag".format(name, n) 110 | recipe.append((node, prev_inv[n] - q)) 111 | new_node = TechNode( 112 | name, 113 | recipe=recipe, 114 | table=prev_inv["crafting_table"] > 0, 115 | furnace=prev_inv["furnace"] > 0, 116 | timestep=self.total_steps+self.steps_since_sync, 117 | iteration=self.iterations 118 | ) 119 | self._add_node(new_node) 120 | 121 | def add_seen_invs(self, inv: Dict[str, int]) -> None: 122 | with self.techtree_lock: 123 | self._sync_info() 124 | filterd_inv = {n: q for n, q in inv.items() if q > 0 and n != "any"} 125 | found = False 126 | to_remove = [] 127 | for i, seen in enumerate(self.seen_invs): 128 | if set(seen.keys()).issubset(set(filterd_inv.keys())): 129 | if not found: 130 | for item in filterd_inv: 131 | seen[item] = max(seen[item] if item in seen else 0, filterd_inv[item]) 132 | found = True 133 | else: 134 | to_remove.append(i) 135 | for i in reversed(to_remove): 136 | self.seen_invs.pop(i) 137 | if not found: 138 | self.seen_invs.append(filterd_inv) 139 | self._save_info() 140 | 141 | def sample_goal(self, last_goal=None, completed=False, success=False) -> TechNode: 142 | 143 | nodes = self._get_candidate_nodes() 144 | with self.techtree_lock: 145 | self._sync_info() 146 | if last_goal is not None: 147 | if completed: 148 | self.node_success[last_goal.name] = .9 * self.node_success[last_goal.name] + .1 * success 149 | else: 150 | self.node_visits[last_goal.name] = max(0, self.node_visits[last_goal.name] - 1) 151 | 152 | nodes = [n for n in nodes if self.node_success[n.name] > -float("inf")] 153 | if any(self.node_visits[s.name] <= 0 for s in nodes): 154 | nodes = [n for n in nodes if self.node_visits[n.name] <= 0] 155 | 156 | goal = np.random.choice(nodes) 157 | 158 | self.iterations += 1 159 | self.node_visits[goal.name] += 1 160 | 161 | self._save_info() 162 | 163 | 
# print("Sampled goal:", goal.name) 164 | return goal 165 | 166 | def update_task(self, goal: TechNode, task: TechNode, tool: TechNode, inv: Dict[str, int]) -> Tuple[TechNode, TechNode, Dict[TechNode, int]]: 167 | assert goal is not None 168 | 169 | ingredients = tool.get_ingredients(inv) if tool is not None else goal.get_ingredients(inv) 170 | if task is not None and task in ingredients and ingredients[task] > 0: 171 | return task, tool 172 | 173 | # Get nodes for benches 174 | crafting_table = self.get_node_by_name("crafting_table") 175 | if crafting_table is None and len(self.guide_trees) > 0: 176 | crafting_table = self.get_node_by_name("crafting_table", guide=True) 177 | furnace = self.get_node_by_name("furnace") 178 | if furnace is None and len(self.guide_trees) > 0: 179 | furnace = self.get_node_by_name("furnace", guide=True) 180 | 181 | # Sample new subgoal 182 | tool = None 183 | new_task = goal 184 | while not new_task.collectable or (new_task.tool is not None and inv[new_task.tool.name] == 0): 185 | ingredients = new_task.get_ingredients(inv) 186 | 187 | subgoals = set([n for n in ingredients.keys() if n.tool is None]) 188 | for n in ingredients: 189 | if n.tool is None or inv[n.tool.name] > 0: 190 | subgoals.add(n) 191 | else: 192 | subgoals.add(n.tool) 193 | if inv["crafting_table"] == 0 and any(n.table for n in new_task.get_craft_order()): 194 | subgoals.add(crafting_table) 195 | if inv["furnace"] == 0 and any(n.furnace for n in new_task.get_craft_order()): 196 | subgoals.add(furnace) 197 | 198 | if len(subgoals) > 0: 199 | new_task = np.random.choice(list(subgoals)) 200 | else: 201 | # If using a guide, we may have the recipe wrong. Sample from the known tree to explore. 202 | new_task = np.random.choice([n for n in self.get_all_nodes() if self.node_success[n.name] > -float("inf")]) 203 | # print("Failed to find task. 
Doing {} instead".format(new_task.name)) 204 | 205 | if not new_task.collectable: 206 | tool = new_task 207 | 208 | # print("Current task/tool:", new_task.name if new_task is not None else "None", 209 | # tool.name if tool is not None else "no tool") 210 | return new_task, tool 211 | 212 | def _get_boundary(self) -> List[TechNode]: 213 | nodes = set() 214 | if len(self.guide_trees) > 0: 215 | if self.target_item is not None: 216 | table_node = self.get_node_by_name("crafting_table", guide=True) 217 | furnace_node = self.get_node_by_name("furnace", guide=True) 218 | guide_nodes = [] 219 | to_add = [self.get_node_by_name(self.target_item, guide=True)] 220 | while len(to_add) > 0: 221 | guide_nodes += to_add 222 | last_nodes = to_add.copy() 223 | to_add = [] 224 | for node in last_nodes: 225 | if node.tool is not None and node.tool not in guide_nodes: 226 | to_add.append(node.tool) 227 | if node.table and table_node not in guide_nodes: 228 | to_add.append(table_node) 229 | if node.furnace and furnace_node not in guide_nodes: 230 | to_add.append(furnace_node) 231 | for n in node.get_subnodes(): 232 | if n not in guide_nodes: 233 | to_add.append(n) 234 | else: 235 | guide_nodes = self.get_all_nodes(guide=True) 236 | 237 | for node in guide_nodes: 238 | if self.get_node_by_name(node.name) is None and \ 239 | (node.tool is None or self.get_node_by_name(node.tool.name) is not None) and \ 240 | (not node.table or self.get_node_by_name("crafting_table") is not None) and \ 241 | (not node.furnace or self.get_node_by_name("furnace") is not None) and \ 242 | (node.collectable or all(x.name in self.tasks for x in node.get_ingredients())) and \ 243 | all(self.get_node_by_name(x.name) is not None for x, _ in node.recipe) and \ 244 | all(x.get_requirements() != node.get_requirements() for x in nodes): 245 | nodes.add(node) 246 | 247 | return list(nodes) 248 | 249 | def _get_candidate_nodes(self) -> List[TechNode]: 250 | if self.target_item is not None: 251 | target_node = self.get_node_by_name(self.target_item) 252 | if target_node is not None: 253 | return [target_node] 254 | nodes = self._get_boundary() 255 | nodes += self.get_all_nodes() 256 | return nodes 257 | 258 | def _get_tech_rank(self, name: str) -> int: 259 | return ([0] + [i + 1 for i, x in enumerate(TECH_ORDER) if name is not None and x in name])[-1] 260 | 261 | def _check_tool(self, tool: TechNode, inv: Dict[str, int]) -> bool: 262 | if "_" in tool.name and tool.name.split("_")[0] in TECH_ORDER: 263 | tech_name = tool.name.split("_")[0] 264 | base_name = tool.name.split("_")[1] 265 | sufficient_tools = [x + "_" + base_name for x in TECH_ORDER \ 266 | if TECH_ORDER.index(x) >= TECH_ORDER.index(tech_name)] 267 | else: 268 | sufficient_tools = [tool.name] 269 | return any(inv[x] > 0 for x in sufficient_tools) 270 | 271 | def _add_node(self, node: TechNode) -> None: 272 | with self.techtree_lock: 273 | self._sync_info() 274 | 275 | if any(node == x or node in x.get_subnodes() for x in self.trees): 276 | return 277 | for child, _ in node.recipe: 278 | if child in self.trees: 279 | self.trees.remove(child) 280 | self.trees.append(node) 281 | 282 | # Update guide dag with known dag 283 | for root in self.guide_trees: 284 | root.update_children_info([node]) 285 | 286 | # Update node success with pretrained tasks 287 | if node.name != "any" and any(x.name not in self.tasks for x in node.get_ingredients().keys()): 288 | self.node_success[node.name] = -float("inf") 289 | 290 | self._save_info() 291 | 292 | def _sync_info(self) -> None: 293 | if not 
os.path.exists(self.techtree_path): 294 | with open(self.techtree_path, "w") as f: 295 | json.dump(dict( 296 | trees=[], seen_invs=[], total_steps=0, iterations=0, node_success=dict(), node_visits=dict() 297 | ), f) 298 | 299 | with open(self.techtree_path, "r") as f: 300 | stored = json.load(f) 301 | self.trees = [TechNode.from_json(x) for x in stored["trees"]] 302 | self.seen_invs = stored["seen_invs"] 303 | self.total_steps = stored["total_steps"] 304 | self.iterations = stored["iterations"] 305 | self.node_success = { 306 | k: stored["node_success"][k] if k in stored["node_success"] else 1 307 | for k in self.node_success.keys() 308 | } 309 | self.node_visits = { 310 | k: stored["node_visits"][k] if k in stored["node_visits"] else 0 311 | for k in self.node_visits.keys() 312 | } 313 | 314 | # Update guide dag with known dag 315 | all_nodes = self.get_all_nodes() 316 | for root in self.guide_trees: 317 | root.update_children_info(all_nodes) 318 | 319 | def _save_info(self) -> None: 320 | self.total_steps += self.steps_since_sync 321 | self.steps_since_sync = 0 322 | with open(self.techtree_path, "w") as f: 323 | json.dump(dict( 324 | trees=[x.to_json() for x in self.trees], 325 | seen_invs=self.seen_invs, 326 | total_steps=self.total_steps, 327 | iterations=self.iterations, 328 | node_success={k: v for k, v in self.node_success.items() if v != 1}, 329 | node_visits={k: v for k, v in self.node_visits.items() if v > 0}, 330 | ), f, indent=4) 331 | -------------------------------------------------------------------------------- /sb3_vpt/algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import gym 3 | from gym import spaces 4 | import torch as th 5 | import numpy as np 6 | from copy import deepcopy 7 | 8 | from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor 9 | from stable_baselines3.common.vec_env import VecEnv 10 | from stable_baselines3.common.buffers import RolloutBuffer 11 | from stable_baselines3.common.callbacks import BaseCallback 12 | 13 | from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy 14 | from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO 15 | from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer 16 | 17 | from sb3_vpt.buffer import VPTBuffer 18 | 19 | 20 | class VPTPPO(RecurrentPPO): 21 | def __init__(self, *args, kl_coef=.2, kl_decay=.9995, use_task_ids=False, n_tasks=1, **kwargs): 22 | self.init_kl_coef = kl_coef 23 | self.kl_coef = self.init_kl_coef 24 | self.kl_decay = kl_decay 25 | self.use_task_ids = use_task_ids 26 | self.n_tasks = n_tasks 27 | super().__init__(*args, **kwargs) 28 | 29 | def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: 30 | return ["policy.optimizer"] + ["policy." 
+ p for p in self.policy.get_param_keys()], [] 31 | 32 | def set_task_id(self, task_id): 33 | self._last_task_id = np.full((self.n_envs,), task_id, dtype=np.int16) 34 | 35 | def _setup_model(self) -> None: 36 | self._setup_lr_schedule() 37 | self.set_random_seed(self.seed) 38 | 39 | self.policy = self.policy_class( 40 | self.observation_space, 41 | self.action_space, 42 | self.lr_schedule, 43 | use_sde=self.use_sde, 44 | **self.policy_kwargs, # pytype:disable=not-instantiable 45 | ) 46 | self.policy = self.policy.to(self.device) 47 | 48 | if not isinstance(self.policy, RecurrentActorCriticPolicy): 49 | raise ValueError("Policy must subclass RecurrentActorCriticPolicy") 50 | 51 | buffer_cls = VPTBuffer 52 | 53 | self._last_vpt_states = self.policy.initial_state(self.n_envs) # num_blocks x batch x buffer, num_blocks x batch x buffer x hidden 54 | self._last_task_id = np.zeros((self.n_envs,), dtype=np.int16) 55 | 56 | hidden_state_buffer_shape = ( 57 | self.n_steps, 58 | self._last_vpt_states[1].shape[0], # num_blocks 59 | self.n_envs, 60 | self._last_vpt_states[1].shape[3] # hidden size 61 | ) 62 | 63 | self.rollout_buffer = buffer_cls( 64 | self.n_steps, 65 | self.observation_space, 66 | self.action_space, 67 | hidden_state_buffer_shape, 68 | device=self.device, 69 | gamma=self.gamma, 70 | gae_lambda=self.gae_lambda, 71 | n_envs=self.n_envs, 72 | ) 73 | 74 | # Initialize schedules for policy/value clipping 75 | self.clip_range = get_schedule_fn(self.clip_range) 76 | if self.clip_range_vf is not None: 77 | if isinstance(self.clip_range_vf, (float, int)): 78 | assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping" 79 | 80 | self.clip_range_vf = get_schedule_fn(self.clip_range_vf) 81 | 82 | def reset(self): 83 | self._last_obs = self.env.reset() 84 | self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) 85 | self._last_vpt_states = self.policy.initial_state(self.n_envs) 86 | self._last_task_id = np.zeros((self.n_envs,), dtype=np.int16) 87 | self.kl_coef = self.init_kl_coef 88 | 89 | def collect_rollouts( 90 | self, 91 | env: VecEnv, 92 | callback: BaseCallback, 93 | rollout_buffer: RolloutBuffer, 94 | n_rollout_steps: int, 95 | ) -> bool: 96 | """ 97 | Collect experiences using the current policy and fill a ``RolloutBuffer``. 98 | The term rollout here refers to the model-free notion and should not 99 | be used with the concept of rollout used in model-based RL or planning. 100 | :param env: The training environment 101 | :param callback: Callback that will be called at each step 102 | (and at the beginning and end of the rollout) 103 | :param rollout_buffer: Buffer to fill with rollouts 104 | :param n_steps: Number of experiences to collect per environment 105 | :return: True if function returned with at least `n_rollout_steps` 106 | collected, False if callback terminated rollout prematurely. 
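Compared to the parent RecurrentPPO implementation, this version also threads the VPT transformer key/value states and per-environment task ids through the rollout, stores the pre-step states in the buffer, and treats a change of task id as an episode boundary.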
107 | """ 108 | assert isinstance( 109 | rollout_buffer, (VPTBuffer, RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) 110 | ), f"{rollout_buffer} doesn't support recurrent policy" 111 | 112 | assert self._last_obs is not None, "No previous observation was provided" 113 | # Switch to eval mode (this affects batch norm / dropout) 114 | self.policy.set_training_mode(False) 115 | 116 | n_steps = 0 117 | rollout_buffer.reset() 118 | # Sample new weights for the state dependent exploration 119 | if self.use_sde: 120 | self.policy.reset_noise(env.num_envs) 121 | 122 | callback.on_rollout_start() 123 | 124 | vpt_states = deepcopy(self._last_vpt_states) 125 | 126 | while n_steps < n_rollout_steps: 127 | if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: 128 | # Sample a new noise matrix 129 | self.policy.reset_noise(env.num_envs) 130 | 131 | with th.no_grad(): 132 | # Convert to pytorch tensor or to TensorDict 133 | obs_tensor = obs_as_tensor(self._last_obs, self.device) 134 | episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) 135 | actions, values, log_probs, vpt_states = self.policy.forward(obs_tensor, vpt_states, episode_starts, self._last_task_id) 136 | 137 | actions = actions.cpu().numpy() # n_envs x 2 138 | 139 | # Rescale and perform action 140 | clipped_actions = actions 141 | # Clip the actions to avoid out of bound error 142 | if isinstance(self.action_space, gym.spaces.Box): 143 | clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) 144 | 145 | new_obs, rewards, dones, infos = env.step(clipped_actions) 146 | 147 | self.num_timesteps += env.num_envs 148 | 149 | # Give access to local variables 150 | callback.update_locals(locals()) 151 | if callback.on_step() is False: 152 | return False 153 | 154 | self._update_info_buffer(infos) 155 | n_steps += 1 156 | 157 | if isinstance(self.action_space, gym.spaces.Discrete): 158 | # Reshape in case of discrete action 159 | actions = actions.reshape(-1, 1) 160 | 161 | # Update dones with task terminals after callbacks 162 | curr_task_id = np.array([ 163 | info["subgoal"] if "subgoal" in info and self.use_task_ids else self._last_task_id[i] 164 | for i, info in enumerate(infos) 165 | ], dtype=np.int16) 166 | dones = np.bitwise_or(dones, curr_task_id != self._last_task_id) 167 | 168 | # Handle timeout by bootstraping with value function 169 | # see GitHub issue #633 170 | for idx, done_ in enumerate(dones): 171 | if ( 172 | done_ 173 | and infos[idx].get("terminal_observation") is not None 174 | and infos[idx].get("TimeLimit.truncated", False) 175 | ): 176 | terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] 177 | with th.no_grad(): 178 | ### VPT Changes Start: alter state field names 179 | terminal_vpt_state = ( 180 | vpt_states.mask[:, idx : idx + 1], 181 | vpt_states.keys[:, idx : idx + 1, :], 182 | vpt_states.values[:, idx : idx + 1, :], 183 | ) 184 | ### VPT Changes end 185 | episode_starts = th.tensor([False]).float().to(self.device) 186 | terminal_task_id = th.tensor([infos[idx]["subgoal"]], dtype=th.int8, device=self.device) 187 | terminal_value = self.policy.predict_values(terminal_obs, terminal_vpt_state, episode_starts, terminal_task_id)[0] 188 | rewards[idx] += self.gamma * terminal_value 189 | 190 | rollout_buffer.add( 191 | self._last_obs, 192 | actions, 193 | rewards, 194 | self._last_episode_starts, 195 | values, 196 | log_probs, 197 | vpt_states=self._last_vpt_states, 198 | task_id=self._last_task_id 199 | ) 200 | 
201 | self._last_obs = new_obs 202 | self._last_episode_starts = dones 203 | self._last_vpt_states = vpt_states 204 | self._last_task_id = curr_task_id 205 | 206 | with th.no_grad(): 207 | # Compute value for the last timestep 208 | episode_starts = th.tensor(dones).float().to(self.device) 209 | ### VPT Changes Start: pass entire state to policy 210 | values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), vpt_states, episode_starts, self._last_task_id) 211 | ### VPT Changes End 212 | 213 | rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) 214 | 215 | callback.on_rollout_end() 216 | 217 | return True 218 | 219 | def train(self) -> None: 220 | """ 221 | Update policy using the currently gathered rollout buffer. 222 | """ 223 | # Switch to train mode (this affects batch norm / dropout) 224 | self.policy.set_training_mode(True) 225 | # Update optimizer learning rate 226 | self._update_learning_rate(self.policy.optimizer) 227 | # Compute current clip range 228 | clip_range = self.clip_range(self._current_progress_remaining) 229 | # Optional: clip range for the value function 230 | if self.clip_range_vf is not None: 231 | clip_range_vf = self.clip_range_vf(self._current_progress_remaining) 232 | 233 | pg_losses, value_losses, kl_losses, losses, bc_losses = [], [], [], [], [] 234 | clip_fractions = [] 235 | 236 | continue_training = True 237 | 238 | # train for n_epochs epochs 239 | for epoch in range(self.n_epochs): 240 | approx_kl_divs = [] 241 | # Do a complete pass on the rollout buffer 242 | for rollout_data in self.rollout_buffer.get(self.batch_size): 243 | actions = rollout_data.actions 244 | if isinstance(self.action_space, spaces.Discrete): 245 | # Convert discrete action from float to long 246 | actions = rollout_data.actions.long().flatten() 247 | 248 | # Convert mask from float to bool 249 | mask = rollout_data.mask > 1e-8 250 | 251 | # Re-sample the noise matrix because the log_std has changed 252 | if self.use_sde: 253 | self.policy.reset_noise(self.batch_size) 254 | 255 | values, log_prob, kl = self.policy.evaluate_actions( 256 | rollout_data.observations, 257 | actions, 258 | rollout_data.vpt_states, # 4, 1, 128, 2048 259 | rollout_data.episode_starts, 260 | rollout_data.task_id 261 | ) 262 | 263 | values = values.flatten() 264 | # Normalize advantage 265 | advantages = rollout_data.advantages 266 | if self.normalize_advantage: 267 | advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) 268 | 269 | # ratio between old and new policy, should be one at the first iteration 270 | ratio = th.exp(log_prob - rollout_data.old_log_prob) 271 | 272 | # clipped surrogate loss 273 | policy_loss_1 = advantages * ratio 274 | policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) 275 | policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) 276 | 277 | # Logging 278 | pg_losses.append(policy_loss.item()) 279 | clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() 280 | clip_fractions.append(clip_fraction) 281 | 282 | if self.clip_range_vf is None: 283 | # No clipping 284 | values_pred = values 285 | else: 286 | # Clip the different between old and new value 287 | # NOTE: this depends on the reward scaling 288 | values_pred = rollout_data.old_values + th.clamp( 289 | values - rollout_data.old_values, -clip_range_vf, clip_range_vf 290 | ) 291 | # Value loss using the TD(gae_lambda) target 292 | value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) 
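# The kl term returned by evaluate_actions measures divergence from the frozen exploration copy of VPT; it is added to the total loss below with weight self.kl_coef, which is decayed by kl_decay at the end of each train() call to gradually relax the pull toward the pretrained prior.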
293 | 294 | value_losses.append(value_loss.item()) 295 | 296 | kl_loss = th.mean(kl) 297 | kl_losses.append(kl_loss.item()) 298 | 299 | loss = policy_loss + self.kl_coef * kl_loss + self.vf_coef * value_loss 300 | losses.append(loss.item()) 301 | 302 | with th.no_grad(): 303 | log_ratio = log_prob - rollout_data.old_log_prob 304 | approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() 305 | approx_kl_divs.append(approx_kl_div) 306 | 307 | if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: 308 | continue_training = False 309 | if self.verbose >= 1: 310 | print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") 311 | break 312 | 313 | # Optimization step 314 | self.policy.optimizer.zero_grad() 315 | loss.backward() 316 | # Clip grad norm 317 | th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) 318 | self.policy.optimizer.step() 319 | 320 | if not continue_training: 321 | break 322 | 323 | self.kl_coef *= self.kl_decay 324 | self._n_updates += self.n_epochs 325 | explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) 326 | 327 | # Logs 328 | if len(bc_losses) > 0: 329 | self.logger.record("train/bc_loss", np.mean(bc_losses)) 330 | self.logger.record("train/kl_loss", np.mean(kl_losses)) 331 | self.logger.record("train/kl_coef", self.kl_coef) 332 | self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) 333 | self.logger.record("train/value_loss", np.mean(value_losses)) 334 | self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) 335 | self.logger.record("train/clip_fraction", np.mean(clip_fractions)) 336 | self.logger.record("train/avg_loss", np.mean(losses)) 337 | self.logger.record("train/loss", loss.item()) 338 | self.logger.record("train/explained_variance", explained_var) 339 | if hasattr(self.policy, "log_std"): 340 | self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) 341 | 342 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 343 | self.logger.record("train/clip_range", clip_range) 344 | if self.clip_range_vf is not None: 345 | self.logger.record("train/clip_range_vf", clip_range_vf) 346 | --------------------------------------------------------------------------------
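VPTPPO.train() above combines PPO's clipped surrogate objective and value loss with a KL penalty toward the frozen exploration copy of VPT. Below is a minimal, self-contained sketch of that loss composition only; the function name vpt_ppo_loss, the random stand-in tensors, and the default coefficient values are illustrative and are not part of this repository.

import torch as th

def vpt_ppo_loss(log_prob, old_log_prob, advantages, values, returns, kl, mask,
                 clip_range=0.2, kl_coef=0.1, vf_coef=1.0):
    # Clipped surrogate objective, averaged over non-padding timesteps
    ratio = th.exp(log_prob - old_log_prob)
    policy_loss = -th.min(advantages * ratio,
                          advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range))[mask].mean()
    # Value loss against the TD(gae_lambda) returns
    value_loss = ((returns - values) ** 2)[mask].mean()
    # KL penalty anchoring the finetuned policy to the pretrained VPT prior
    return policy_loss + vf_coef * value_loss + kl_coef * kl.mean()

# Toy check with random tensors standing in for one padded minibatch
n = 64
mask = th.rand(n) > 0.1  # True where the timestep is real rather than padding
loss = vpt_ppo_loss(th.randn(n), th.randn(n), th.randn(n), th.randn(n), th.randn(n),
                    th.rand(n), mask)
print(loss.item())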