├── banner.png
├── setup.py
├── homegrid
│   ├── assets
│   │   ├── chair.png
│   │   ├── egg_0.png
│   │   ├── egg_1.png
│   │   ├── egg_2.png
│   │   ├── egg_3.png
│   │   ├── lamp.png
│   │   ├── plant.png
│   │   ├── robot.png
│   │   ├── rugl.png
│   │   ├── rugr.png
│   │   ├── sofa.png
│   │   ├── stove.png
│   │   ├── table.png
│   │   ├── tile.png
│   │   ├── wood.png
│   │   ├── cabinet.png
│   │   ├── carpet.png
│   │   ├── chairl.png
│   │   ├── chairr.png
│   │   ├── cupboard.png
│   │   ├── fridge.png
│   │   ├── tomato_0.png
│   │   ├── tomato_1.png
│   │   ├── tomato_2.png
│   │   ├── countertop.png
│   │   ├── sofa_side.png
│   │   ├── coffeetable.png
│   │   ├── objects
│   │   │   ├── fruit.png
│   │   │   ├── bottle.png
│   │   │   ├── papers.png
│   │   │   └── plates.png
│   │   └── bins
│   │       ├── trash_bin.png
│   │       ├── compost_bin.png
│   │       ├── recycling_bin.png
│   │       ├── trash_bin_open.png
│   │       ├── compost_bin_closed.png
│   │       ├── compost_bin_open.png
│   │       ├── recycling_bin_open.png
│   │       ├── trash_bin_closed.png
│   │       └── recycling_bin_closed.png
│   ├── homegrid_embeds.pkl
│   ├── __init__.py
│   ├── benchmark.py
│   ├── wrappers.py
│   ├── window.py
│   ├── homegrid_sentences.txt
│   ├── manual_control.py
│   ├── rendering.py
│   ├── layout.py
│   ├── language_wrappers.py
│   ├── homegrid_base.py
│   └── base.py
├── pyproject.toml
├── LICENSE
├── scripts
│   └── embed_offline.py
├── .gitignore
└── README.md

/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/banner.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | if __name__ == "__main__":
4 |     setuptools.setup()
5 | 
--------------------------------------------------------------------------------
/homegrid/assets/chair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/chair.png
--------------------------------------------------------------------------------
/homegrid/assets/egg_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/egg_0.png
--------------------------------------------------------------------------------
/homegrid/assets/egg_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/egg_1.png
--------------------------------------------------------------------------------
/homegrid/assets/egg_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/egg_2.png
--------------------------------------------------------------------------------
/homegrid/assets/egg_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/egg_3.png
--------------------------------------------------------------------------------
/homegrid/assets/lamp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/lamp.png
--------------------------------------------------------------------------------
/homegrid/assets/plant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/plant.png
-------------------------------------------------------------------------------- /homegrid/assets/robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/robot.png -------------------------------------------------------------------------------- /homegrid/assets/rugl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/rugl.png -------------------------------------------------------------------------------- /homegrid/assets/rugr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/rugr.png -------------------------------------------------------------------------------- /homegrid/assets/sofa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/sofa.png -------------------------------------------------------------------------------- /homegrid/assets/stove.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/stove.png -------------------------------------------------------------------------------- /homegrid/assets/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/table.png -------------------------------------------------------------------------------- /homegrid/assets/tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/tile.png -------------------------------------------------------------------------------- /homegrid/assets/wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/wood.png -------------------------------------------------------------------------------- /homegrid/assets/cabinet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/cabinet.png -------------------------------------------------------------------------------- /homegrid/assets/carpet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/carpet.png -------------------------------------------------------------------------------- /homegrid/assets/chairl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/chairl.png -------------------------------------------------------------------------------- /homegrid/assets/chairr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/chairr.png -------------------------------------------------------------------------------- /homegrid/assets/cupboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/cupboard.png 
-------------------------------------------------------------------------------- /homegrid/assets/fridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/fridge.png -------------------------------------------------------------------------------- /homegrid/assets/tomato_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/tomato_0.png -------------------------------------------------------------------------------- /homegrid/assets/tomato_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/tomato_1.png -------------------------------------------------------------------------------- /homegrid/assets/tomato_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/tomato_2.png -------------------------------------------------------------------------------- /homegrid/homegrid_embeds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/homegrid_embeds.pkl -------------------------------------------------------------------------------- /homegrid/assets/countertop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/countertop.png -------------------------------------------------------------------------------- /homegrid/assets/sofa_side.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/sofa_side.png -------------------------------------------------------------------------------- /homegrid/assets/coffeetable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/coffeetable.png -------------------------------------------------------------------------------- /homegrid/assets/objects/fruit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/objects/fruit.png -------------------------------------------------------------------------------- /homegrid/assets/bins/trash_bin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/trash_bin.png -------------------------------------------------------------------------------- /homegrid/assets/objects/bottle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/objects/bottle.png -------------------------------------------------------------------------------- /homegrid/assets/objects/papers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/objects/papers.png -------------------------------------------------------------------------------- /homegrid/assets/objects/plates.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/objects/plates.png -------------------------------------------------------------------------------- /homegrid/assets/bins/compost_bin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/compost_bin.png -------------------------------------------------------------------------------- /homegrid/assets/bins/recycling_bin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/recycling_bin.png -------------------------------------------------------------------------------- /homegrid/assets/bins/trash_bin_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/trash_bin_open.png -------------------------------------------------------------------------------- /homegrid/assets/bins/compost_bin_closed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/compost_bin_closed.png -------------------------------------------------------------------------------- /homegrid/assets/bins/compost_bin_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/compost_bin_open.png -------------------------------------------------------------------------------- /homegrid/assets/bins/recycling_bin_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/recycling_bin_open.png -------------------------------------------------------------------------------- /homegrid/assets/bins/trash_bin_closed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/trash_bin_closed.png -------------------------------------------------------------------------------- /homegrid/assets/bins/recycling_bin_closed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlin816/homegrid/HEAD/homegrid/assets/bins/recycling_bin_closed.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "homegrid" 3 | version = "0.1.1" 4 | description = "A minimal home gridworld environment to test how agents use language hints." 
5 | authors = ["Jessy Lin "] 6 | readme = "README.md" 7 | keywords = ["environment", "agent", "rl", "language"] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.8" 11 | gym = { version = "0.26" } 12 | numpy = "*" 13 | matplotlib = "*" 14 | tokenizers = "*" 15 | sentencepiece = "*" 16 | transformers = { version = "*", optional = true } 17 | torch = { version = "*", optional = true } 18 | 19 | [tool.poetry.extras] 20 | # For pre-embedding new sentences 21 | dev = ["transformers", "torch"] 22 | 23 | [build-system] 24 | requires = ["poetry-core"] 25 | build-backend = "poetry.core.masonry.api" 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jessy Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/homegrid/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | from homegrid.homegrid_base import HomeGridBase
3 | from homegrid.language_wrappers import MultitaskWrapper, LanguageWrapper
4 | from homegrid.wrappers import RGBImgPartialObsWrapper, FilterObsWrapper
5 | 
6 | import warnings
7 | warnings.filterwarnings("ignore", module="gym.utils.passive_env_checker")
8 | warnings.filterwarnings("ignore", module="gym.spaces.box")
9 | 
10 | class HomeGrid:
11 | 
12 |     def __init__(self, lang_types, *args, **kwargs):
13 |         env = HomeGridBase(*args, **kwargs)
14 |         env = RGBImgPartialObsWrapper(env)
15 |         env = FilterObsWrapper(env, ["image"])
16 |         env = MultitaskWrapper(env)
17 |         env = LanguageWrapper(
18 |             env,
19 |             preread_max=28,
20 |             repeat_task_every=20,
21 |             p_language=0.2,
22 |             lang_types=lang_types,
23 |         )
24 |         self.env = env
25 | 
26 |     def __getattr__(self, name):
27 |         return getattr(self.env, name)
28 | 
29 |     def reset(self):
30 |         return self.env.reset()
31 | 
32 |     def step(self, action):
33 |         return self.env.step(action)
34 | 
35 | register(
36 |     id="homegrid-task",
37 |     entry_point="homegrid:HomeGrid",
38 |     kwargs={"lang_types": ["task"]},
39 | )
40 | 
41 | register(
42 |     id="homegrid-future",
43 |     entry_point="homegrid:HomeGrid",
44 |     kwargs={"lang_types": ["task", "future"]},
45 | )
46 | 
47 | register(
48 |     id="homegrid-dynamics",
49 |     entry_point="homegrid:HomeGrid",
50 |     kwargs={"lang_types": ["task", "dynamics"]},
51 | )
52 | 
53 | register(
54 |     id="homegrid-corrections",
55 |     entry_point="homegrid:HomeGrid",
56 |     kwargs={"lang_types": ["task", "corrections"]}
57 | )
58 | 
--------------------------------------------------------------------------------
/homegrid/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import time
4 | 
5 | import gym
6 | 
7 | from homegrid.wrappers import RGBImgPartialObsWrapper
8 | 
9 | 
10 | def benchmark(env_id, num_resets, num_frames):
11 |     env = gym.make(env_id, disable_env_checker=True)
12 |     # Benchmark env.reset
13 |     t0 = time.time()
14 |     for i in range(num_resets):
15 |         env.reset()
16 |     t1 = time.time()
17 |     dt = t1 - t0
18 |     reset_time = (1000 * dt) / num_resets
19 | 
20 |     # Benchmark rendering
21 |     t0 = time.time()
22 |     for i in range(num_frames):
23 |         env.render()
24 |     t1 = time.time()
25 |     dt = t1 - t0
26 |     frames_per_sec = num_frames / dt
27 | 
28 |     # Create an environment with an RGB agent observation
29 |     env = gym.make(env_id, disable_env_checker=True)
30 |     env = RGBImgPartialObsWrapper(env)
31 | 
32 |     env.reset()
33 |     # Benchmark rendering
34 |     t0 = time.time()
35 |     for i in range(num_frames):
36 |         obs, reward, terminated, truncated, info = env.step(0)
37 |     t1 = time.time()
38 |     dt = t1 - t0
39 |     agent_view_fps = num_frames / dt
40 | 
41 |     print(f"Env reset time: {reset_time:.1f} ms")
42 |     print(f"Rendering FPS : {frames_per_sec:.0f}")
43 |     print(f"Agent view FPS: {agent_view_fps:.0f}")
44 | 
45 |     env.close()
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     import argparse
50 | 
51 |     parser = argparse.ArgumentParser()
52 |     parser.add_argument(
53 |         "--env-id",
54 |         dest="env_id",
55 |         help="gym environment to load",
56 |         default="homegrid-task",
57 |     )
58 |     parser.add_argument("--num_resets", type=int, default=200)
59 |     parser.add_argument("--num_frames", type=int, default=5000)
60 |     args = parser.parse_args()
61 |     benchmark(args.env_id, args.num_resets, args.num_frames)
62 | 
--------------------------------------------------------------------------------
/homegrid/wrappers.py:
--------------------------------------------------------------------------------
1 | from gym.core import Wrapper, ObservationWrapper
2 | import gym
3 | from gym import spaces
4 | from typing import List, Union
5 | import numpy as np
6 | 
7 | class Gym26Wrapper(Wrapper):
8 |     """Wraps gym v0.26 env with a ~v0.22 API so it can be used with sb3.
9 | 
10 |     Refer to gym wrappers for opposite compatibility (old env -> new API):
11 |     https://github.com/openai/gym/blob/master/gym/wrappers/compatibility.py
12 |     """
13 | 
14 |     def reset(self, **kwargs):
15 |         obs, info = self.env.reset(**kwargs)
16 |         return obs
17 | 
18 |     def step(self, action):
19 |         obs, reward, terminated, truncated, info = self.env.step(action)
20 |         done = terminated or truncated
21 |         return obs, reward, done, info
22 | 
23 |     def render(self, mode):
24 |         return self.env.render()
25 | 
26 | class FilterObsWrapper(ObservationWrapper):
27 | 
28 |     def __init__(self, env, obs_keys: List[str]):
29 |         super().__init__(env)
30 |         self.obs_keys = obs_keys
31 |         self.observation_space = spaces.Dict({
32 |             k: v for k, v in env.observation_space.items() if k in self.obs_keys
33 |         })
34 | 
35 |     def observation(self, obs):
36 |         return {k: v for k, v in obs.items() if k in self.obs_keys}
37 | 
38 | class RGBImgPartialObsWrapper(ObservationWrapper):
39 |     """RGBImg wrapper that also preserves the original symbolic observation."""
40 | 
41 |     def __init__(self, env, tile_size=32):
42 |         super().__init__(env)
43 | 
44 |         # Rendering attributes for observations
45 |         self.tile_size = tile_size
46 | 
47 |         obs_shape = env.observation_space.spaces["image"].shape
48 |         new_image_space = spaces.Box(
49 |             low=0,
50 |             high=255,
51 |             shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
52 |             dtype="uint8",
53 |         )
54 | 
55 |         self.observation_space = spaces.Dict(
56 |             {**self.observation_space.spaces, "image": new_image_space}
57 |         )
58 | 
59 |     def observation(self, obs):
60 |         rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
61 | 
62 |         return {**obs, "symbolic_image": obs["image"], "image": rgb_img_partial}
--------------------------------------------------------------------------------
/scripts/embed_offline.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from transformers import T5Tokenizer, T5EncoderModel
3 | import sys
4 | import torch
5 | 
6 | def embed_t5(sentences, outf):
7 |     tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
8 |     model = T5EncoderModel.from_pretrained("t5-small")
9 |     embed_cache = {}
10 |     token_cache = {}
11 |     lens = []
12 |     for s in sentences:
13 |         s = s.strip()
14 |         tokens = tokenizer(s, return_tensors="pt", add_special_tokens=False)
15 |         with torch.no_grad():
16 |             embed = model(**tokens).last_hidden_state.squeeze(0).cpu().numpy()
17 |         token_cache[s] = tokens["input_ids"].cpu().numpy()[0]
18 |         embed_cache[s] = embed
19 |         lens.append(len(token_cache[s]))
20 | 
21 |     token_cache[""] = tokenizer.pad_token_id
22 |     with torch.no_grad():
23 |         embed_cache[""] = model(
24 |             **tokenizer("", add_special_tokens=False, return_tensors="pt")
25 |         ).last_hidden_state[0][0].cpu().numpy()
26 |     assert embed_cache[""].shape == (model.config.d_model,), \
27 |         embed_cache[""].shape
28 |     with open(outf, "wb") as f:
29 |         pickle.dump((token_cache, embed_cache), f)
30 |     print(lens)
31 |     print("max len: ", max(lens))
32 | 
33 | def embed_st(sentences, outf):
34 |     from sentence_transformers import SentenceTransformer
35 |     model = SentenceTransformer("all-distilroberta-v1")
36 |     embeds = {}
37 |     for sent in sentences:
38 |         embeds[sent] = model.encode([sent])[0]
39 |     embeds[""] = model.encode([""])[0]
40 |     assert embeds[""].shape == (768,), \
41 |         embeds[""].shape
42 |     with open(outf, "wb") as f:
43 |         pickle.dump((None, embeds), f)
44 | 
45 | if __name__ == "__main__":
46 |     import argparse
47 | 
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument(
50 |         "--infile", help="file with strings to embed",
51 |         default="homegrid/homegrid_sentences.txt"
52 |     )
53 |     parser.add_argument(
54 |         "--outfile", help="file to output token and embedding cache",
55 |         default="homegrid/homegrid_embeds.pkl"
56 |     )
57 | 
58 |     parser.add_argument(
59 |         "--model",
60 |         help="model to use for embedding (t5 or sentence)",
61 |         choices=["t5", "sentence"],
62 |         default="t5",
63 |     )
64 | 
65 |     args = parser.parse_args()
66 |     with open(args.infile) as f:
67 |         sentences = []
68 |         for line in f:
69 |             sentences.append(line.strip())
70 | 
71 |     if args.model == "t5":
72 |         embed_t5(sentences, args.outfile)
73 |     elif args.model == "sentence":
74 |         embed_st(sentences, args.outfile)
75 | 
--------------------------------------------------------------------------------
/homegrid/window.py:
--------------------------------------------------------------------------------
1 | # Only ask users to install matplotlib if they actually need it
2 | try:
3 |     import matplotlib.pyplot as plt
4 | except ImportError:
5 |     raise ImportError(
6 |         "To display the environment in a window, please install matplotlib, e.g.: `pip3 install --user matplotlib`"
7 |     )
8 | 
9 | 
10 | class Window:
11 |     """
12 |     Window to draw a gridworld instance using Matplotlib
13 |     """
14 | 
15 |     def __init__(self, title):
16 |         self.no_image_shown = True
17 | 
18 |         # Create the figure and axes
19 |         self.fig, self.ax = plt.subplots()
20 | 
21 |         # Show the env name in the window title
22 |         self.fig.canvas.manager.set_window_title(title)
23 | 
24 |         # Turn off x/y axis numbering/ticks
25 |         self.ax.xaxis.set_ticks_position("none")
26 |         self.ax.yaxis.set_ticks_position("none")
27 |         _ = self.ax.set_xticklabels([])
28 |         _ = self.ax.set_yticklabels([])
29 | 
30 |         # Flag indicating the window was closed
31 |         self.closed = False
32 | 
33 |         def close_handler(evt):
34 |             self.closed = True
35 | 
36 |         self.fig.canvas.mpl_connect("close_event", close_handler)
37 | 
38 |     def show_img(self, img):
39 |         """
40 |         Show an image or update the image being shown
41 |         """
42 | 
43 |         # If no image has been shown yet,
44 |         # show the first image of the environment
45 |         if self.no_image_shown:
46 |             self.imshow_obj = self.ax.imshow(img, interpolation="bilinear")
47 |             self.no_image_shown = False
48 |         # Update the image data
49 |         self.imshow_obj.set_data(img)
50 | 
51 |         # Request the window be redrawn
52 |         self.fig.canvas.draw_idle()
53 |         self.fig.canvas.flush_events()
54 | 
55 |         # Let matplotlib process UI events
56 |         plt.pause(0.001)
57 | 
58 |     def set_caption(self, text):
59 |         """
60 |         Set/update the caption text below the image
61 |         """
62 | 
63 |         plt.xlabel(text)
64 | 
65 |     def reg_key_handler(self, key_handler):
66 |         """
67 |         Register a keyboard event handler
68 |         """
69 | 
70 |         # Keyboard handler
71 |         self.fig.canvas.mpl_connect("key_press_event", key_handler)
72 | 
73 |     def show(self, block=True):
74 |         """
75 |         Show the window, and start an event loop
76 |         """
77 | 
78 |         # If not blocking, trigger interactive mode
79 |         if not block:
80 |             plt.ion()
81 | 
82 |         # Show the plot
83 |         # In non-interactive mode, this enters the matplotlib event loop
84 |         # In interactive mode, this call does not block
85 |         plt.show()
86 | 
87 |     def close(self):
88 |         """
89 |         Close the window
90 |         """
91 |         plt.close()
92 |         self.closed = True
93 | 
--------------------------------------------------------------------------------
/homegrid/homegrid_sentences.txt:
--------------------------------------------------------------------------------
1 | go to the kitchen
2 | go to the living room
3 | go to the dining room
4 | find the bottle
5 | find the fruit
6 | find the papers
7 | find the plates
8 | find the recycling bin
9 | find the trash bin
10 | find the compost bin
11 | get the bottle
12 | get the fruit
13 | get the papers
14 | get the plates
15 | put the bottle in the recycling bin
16 | put the bottle in the trash bin
17 | put the bottle in the compost bin
18 | put the fruit in the recycling bin
19 | put the fruit in the trash bin
20 | put the fruit in the compost bin
21 | put the papers in the recycling bin
22 | put the papers in the trash bin
23 | put the papers in the compost bin
24 | put the plates in the recycling bin
25 | put the plates in the trash bin
26 | put the plates in the compost bin
27 | move the bottle to the kitchen
28 | move the bottle to the living room
29 | move the bottle to the dining room
30 | move the fruit to the kitchen
31 | move the fruit to the living room
32 | move the fruit to the dining room
33 | move the papers to the kitchen
34 | move the papers to the living room
35 | move the papers to the dining room
36 | move the plates to the kitchen
37 | move the plates to the living room
38 | move the plates to the dining room
39 | bottle is in the kitchen
40 | bottle is in the living room
41 | bottle is in the dining room
42 | fruit is in the kitchen
43 | fruit is in the living room
44 | fruit is in the dining room
45 | papers is in the kitchen
46 | papers is in the living room
47 | papers is in the dining room
48 | plates is in the kitchen
49 | plates is in the living room
50 | plates is in the dining room
51 | trash bin is in the kitchen
52 | trash bin is in the living room
53 | trash bin is in the dining room
54 | recycling bin is in the kitchen
55 | recycling bin is in the living room
56 | recycling bin is in the dining room
57 | compost bin is in the kitchen
58 | compost bin is in the living room
59 | compost bin is in the dining room
60 | i moved the bottle to the kitchen
61 | i moved the bottle to the living room
62 | i moved the bottle to the dining room
63 | i moved the fruit to the kitchen
64 | i moved the fruit to the living room
65 | i moved the fruit to the dining room
66 | i moved the papers to the kitchen
67 | i moved the papers to the living room
68 | i moved the papers to the dining room
69 | i moved the plates to the kitchen
70 | i moved the plates to the living room
71 | i moved the plates to the dining room
72 | there will be bottle in the kitchen later
73 | there will be bottle in the living room later
74 | there will be bottle in the dining room later
75 | there will be fruit in the kitchen later
76 | there will be fruit in the living room later
77 | there will be fruit in the dining room later
78 | there will be papers in the kitchen later
79 | there will be papers in the living room later
80 | there will be papers in the dining room later
81 | there will be plates in the kitchen later
82 | there will be plates in the living room later
83 | there will be plates in the dining room later
84 | pedal to open the trash bin 85 | grasp to open the trash bin 86 | lift to open the trash bin 87 | pedal to open the recycling bin 88 | grasp to open the recycling bin 89 | lift to open the recycling bin 90 | pedal to open the compost bin 91 | grasp to open the compost bin 92 | lift to open the compost bin 93 | open the trash bin 94 | open the trash bin 95 | open the trash bin 96 | open the recycling bin 97 | open the recycling bin 98 | open the recycling bin 99 | open the compost bin 100 | open the compost bin 101 | open the compost bin 102 | no, turn around 103 | spill near the trash bin 104 | spill near the recycling bin 105 | spill near the compost bin 106 | i cleaned the spill 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /homegrid/manual_control.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import gym 4 | 5 | from homegrid.window import Window 6 | import matplotlib.pyplot as plt 7 | from tokenizers import Tokenizer 8 | tok = Tokenizer.from_pretrained("t5-small") 9 | 10 | def redraw(window, img): 11 | window.show_img(img) 12 | 13 | 14 | def reset(env, window, seed=None, agent_view=False): 15 | obs, _ = env.reset() 16 | img = obs["image"] if agent_view else env.get_frame() 17 | redraw(window, img) 18 | 19 | 20 | def step(env, window, action, agent_view=False): 21 | obs, reward, terminated, truncated, info = env.step(action) 22 | print(info["symbolic_state"]) 23 | token = tok.decode([obs["token"]]) 24 | print(f"step={env.step_cnt}, reward={reward:.2f}") 25 | print("Token: ", token) 26 | print("Language: ", obs["log_language_info"] if "log_language_info" in obs else "None") 27 | print("Task: ", env.task) 28 | print("-"*20) 29 | window.set_caption( 30 | f"r={reward:.2f} token_id={obs['token']} token=" 31 | f"{token} \ncurrent: {obs['log_language_info'][:50]}...") 32 | 33 | if terminated: 34 | print(f"terminated! 
r={reward}") 35 | reset(env, window) 36 | elif truncated: 37 | print("truncated!") 38 | reset(env, window) 39 | else: 40 | img = obs["image"] if agent_view else env.get_frame() 41 | redraw(window, img) 42 | 43 | 44 | def key_handler(env, window, event, agent_view=False): 45 | print("pressed", event.key) 46 | step_ = lambda a: step(env, window, a, agent_view) 47 | 48 | if event.key == "escape": 49 | window.close() 50 | return 51 | 52 | if event.key == "backspace": 53 | reset(env, window) 54 | return 55 | 56 | if event.key == "left": 57 | step_(env.actions.left) 58 | return 59 | if event.key == "right": 60 | step_(env.actions.right) 61 | return 62 | if event.key == "up": 63 | step_(env.actions.up) 64 | return 65 | if event.key == "down": 66 | step_(env.actions.down) 67 | return 68 | 69 | if event.key == "k": 70 | step_(env.actions.pickup) 71 | return 72 | if event.key == "d": 73 | step_(env.actions.drop) 74 | return 75 | if event.key == "g": 76 | step_(env.actions.get) 77 | return 78 | if event.key == "p": 79 | step_(env.actions.pedal) 80 | return 81 | if event.key == "r": 82 | step_(env.actions.grasp) 83 | return 84 | if event.key == "l": 85 | step_(env.actions.lift) 86 | return 87 | 88 | 89 | if __name__ == "__main__": 90 | import argparse 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument( 94 | "--env", help="gym environment to load", default="homegrid-task" 95 | ) 96 | parser.add_argument( 97 | "--seed", 98 | type=int, 99 | help="random seed to generate the environment with", 100 | default=-1, 101 | ) 102 | parser.add_argument( 103 | "--tile_size", type=int, help="size at which to render tiles", default=32 104 | ) 105 | parser.add_argument( 106 | "--agent_view", 107 | default=False, 108 | help="draw the agent sees (partially observable view)", 109 | action="store_true", 110 | ) 111 | 112 | args = parser.parse_args() 113 | env = gym.make(args.env, disable_env_checker=True) 114 | 115 | for k in plt.rcParams: 116 | if "keymap" in k: 117 | plt.rcParams[k] = [] 118 | window = Window("homegrid - " + args.env) 119 | 120 | window.reg_key_handler(lambda event: key_handler(env, window, event, 121 | args.agent_view)) 122 | 123 | seed = None if args.seed == -1 else args.seed 124 | reset(env, window, seed, args.agent_view) 125 | 126 | # Blocking event loop 127 | window.show(block=True) 128 | -------------------------------------------------------------------------------- /homegrid/rendering.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import pathlib 6 | 7 | 8 | def downsample(img, factor): 9 | """ 10 | Downsample an image along both dimensions by some factor 11 | """ 12 | 13 | assert img.shape[0] % factor == 0 14 | assert img.shape[1] % factor == 0 15 | 16 | img = img.reshape( 17 | [img.shape[0] // factor, factor, img.shape[1] // factor, factor, 3] 18 | ) 19 | img = img.mean(axis=3) 20 | img = img.mean(axis=1) 21 | 22 | return img 23 | 24 | def resize(img, size): 25 | image = Image.fromarray(img) 26 | image = image.resize(size[::-1], resample=Image.NEAREST) 27 | image = np.array(image) 28 | return image 29 | 30 | def draw_obj(bg, obj): 31 | obj = resize(obj, (bg.shape[1], bg.shape[0])) 32 | if obj.shape[-1] == 3: 33 | obj = np.concatenate((obj, 255 * np.ones_like(obj[..., :1])), 34 | axis=-1) 35 | bg = bg[..., :3] 36 | bg[:] = (obj[:, :, 3:] == 255).astype(int) * obj[:, :, :3] + \ 37 | (obj[:, :, 3:] == 0).astype(int) * bg 38 | return bg 39 | 40 | def fill_coords(img, 
fn, color): 41 | """ 42 | Fill pixels of an image with coordinates matching a filter function 43 | """ 44 | for y in range(img.shape[0]): 45 | for x in range(img.shape[1]): 46 | yf = (y + 0.5) / img.shape[0] 47 | xf = (x + 0.5) / img.shape[1] 48 | if fn(xf, yf): 49 | img[y, x] = color 50 | 51 | return img 52 | 53 | 54 | def rotate_fn(fin, cx, cy, theta): 55 | def fout(x, y): 56 | x = x - cx 57 | y = y - cy 58 | 59 | x2 = cx + x * math.cos(-theta) - y * math.sin(-theta) 60 | y2 = cy + y * math.cos(-theta) + x * math.sin(-theta) 61 | 62 | return fin(x2, y2) 63 | 64 | return fout 65 | 66 | 67 | def point_in_line(x0, y0, x1, y1, r): 68 | p0 = np.array([x0, y0], dtype=np.float32) 69 | p1 = np.array([x1, y1], dtype=np.float32) 70 | dir = p1 - p0 71 | dist = np.linalg.norm(dir) 72 | dir = dir / dist 73 | 74 | xmin = min(x0, x1) - r 75 | xmax = max(x0, x1) + r 76 | ymin = min(y0, y1) - r 77 | ymax = max(y0, y1) + r 78 | 79 | def fn(x, y): 80 | # Fast, early escape test 81 | if x < xmin or x > xmax or y < ymin or y > ymax: 82 | return False 83 | 84 | q = np.array([x, y]) 85 | pq = q - p0 86 | 87 | # Closest point on line 88 | a = np.dot(pq, dir) 89 | a = np.clip(a, 0, dist) 90 | p = p0 + a * dir 91 | 92 | dist_to_line = np.linalg.norm(q - p) 93 | return dist_to_line <= r 94 | 95 | return fn 96 | 97 | 98 | def point_in_circle(cx, cy, r): 99 | def fn(x, y): 100 | return (x - cx) * (x - cx) + (y - cy) * (y - cy) <= r * r 101 | 102 | return fn 103 | 104 | 105 | def point_in_rect(xmin, xmax, ymin, ymax): 106 | def fn(x, y): 107 | return x >= xmin and x <= xmax and y >= ymin and y <= ymax 108 | 109 | return fn 110 | 111 | 112 | def point_in_triangle(a, b, c): 113 | a = np.array(a, dtype=np.float32) 114 | b = np.array(b, dtype=np.float32) 115 | c = np.array(c, dtype=np.float32) 116 | 117 | def fn(x, y): 118 | v0 = c - a 119 | v1 = b - a 120 | v2 = np.array((x, y)) - a 121 | 122 | # Compute dot products 123 | dot00 = np.dot(v0, v0) 124 | dot01 = np.dot(v0, v1) 125 | dot02 = np.dot(v0, v2) 126 | dot11 = np.dot(v1, v1) 127 | dot12 = np.dot(v1, v2) 128 | 129 | # Compute barycentric coordinates 130 | inv_denom = 1 / (dot00 * dot11 - dot01 * dot01) 131 | u = (dot11 * dot02 - dot01 * dot12) * inv_denom 132 | v = (dot00 * dot12 - dot01 * dot02) * inv_denom 133 | 134 | # Check if point is in triangle 135 | return (u >= 0) and (v >= 0) and (u + v) < 1 136 | 137 | return fn 138 | 139 | 140 | def highlight_img(img, color=(255, 255, 255), alpha=0.30): 141 | """ 142 | Add highlighting to an image 143 | """ 144 | 145 | blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img) 146 | blend_img = blend_img.clip(0, 255).astype(np.uint8) 147 | img[:, :, :] = blend_img 148 | -------------------------------------------------------------------------------- /homegrid/layout.py: -------------------------------------------------------------------------------- 1 | """Store room layouts and assets.""" 2 | 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from PIL import Image 7 | import numpy as np 8 | 9 | from homegrid.base import ( 10 | Wall, Storage, Inanimate, Pickable, FloorWithObject 11 | ) 12 | from homegrid.rendering import draw_obj 13 | 14 | @dataclass 15 | class RoomSpec: 16 | name: str 17 | texture: str 18 | 19 | 20 | ROOMS = { 21 | "K": RoomSpec("kitchen", "tile"), 22 | "L": RoomSpec("living room", "carpet"), 23 | "D": RoomSpec("dining room", "wood"), 24 | } 25 | room2name = {k: v.name for k, v in ROOMS.items()} 26 | # All static and interactive objects 
27 | OBJECTS = [ 28 | "cupboard", 29 | "stove", 30 | "fridge", 31 | "countertop", 32 | "chairl", 33 | "chairr", 34 | "table", 35 | "sofa", 36 | "sofa_side", 37 | "rugl", 38 | "rugr", 39 | "coffeetable", 40 | "cabinet", 41 | "plant", 42 | ] 43 | # Objects that can have things placed on top of them 44 | SURFACES = { 45 | "cupboard", 46 | "stove", 47 | "countertop", 48 | "chairl", 49 | "chairr", 50 | "table", 51 | "sofa", 52 | "sofa_side", 53 | "rugl", 54 | "rugr", 55 | "coffeetable", 56 | } 57 | # Trash objects 58 | TRASH = [ 59 | "bottle", 60 | "fruit", 61 | "papers", 62 | "plates", 63 | ] 64 | # Trash receptacles 65 | CANS = [ 66 | "recycling_bin", 67 | "trash_bin", 68 | "compost_bin", 69 | ] 70 | 71 | 72 | class ThreeRoom: 73 | LAYOUT = { 74 | # Room layout, keys into ROOMS 75 | "rooms": 76 | """......WWWWWWWW 77 | ......WLLLLLLW 78 | ......WLLLLLLW 79 | ......WLLLLLLW 80 | ......WLLLLLLW 81 | WWWWWWWWLLLLLW 82 | WKKKKKKWDDDDDW 83 | WKKKKKKKDDDDDW 84 | WKKKKKKKDDDDDW 85 | WKKKKKKKDDDDDW 86 | WKKKKKKWDDDDDW 87 | WWWWWWWWWWWWWW 88 | """.splitlines(), 89 | # Static (non-interactive) objects, keys into OBJECTS 90 | "fixed_objects": 91 | """......WWWWWWWW 92 | ......WnhlhLmW 93 | ......WiLLLLLW 94 | ......WiLjkLLW 95 | ......WLLLLLLW 96 | WWWWWWWWLLLLLW 97 | WaabbacWDDDDDW 98 | WKKKKKKKDDDDDW 99 | WKKKKKKKDegfDW 100 | WKKKKKKKDegfDW 101 | WddddddWDDDDDW 102 | WWWWWWWWWWWWWW 103 | """.splitlines(), 104 | # Valid positions to place agent and objects 105 | "valid_poss": { 106 | "agent_start": 107 | """.............. 108 | .............. 109 | ........xxxxx. 110 | ........xxxxx. 111 | ........xxxxx. 112 | ........xxxxx. 113 | ........xxxxx. 114 | .xxxxxxxxxxxx. 115 | .xxxxxxxx...x. 116 | .xxxxxxxx...x. 117 | ........xxxxx. 118 | .............. 119 | """.splitlines(), 120 | "obj": 121 | """.............. 122 | .........x.... 123 | ........x...x. 124 | .............. 125 | .......x...... 126 | .............. 127 | ..x..x........ 128 | ............. 129 | ..........x... 130 | ..........x... 131 | ..xx.......... 132 | .............. 133 | """.splitlines(), 134 | "can": 135 | """.............. 136 | ...........x.. 137 | .............. 138 | .............. 139 | .............. 140 | .............. 141 | .............. 142 | .............. 143 | .............. 144 | .x............ 145 | ............x. 146 | .............. 147 | """.splitlines(), 148 | }} 149 | 150 | def __init__(self): 151 | self._load_textures() 152 | self._parse_valid_poss() 153 | self.width = len(self.LAYOUT["rooms"][0]) 154 | self.height = len(self.LAYOUT["rooms"]) 155 | self.room_to_cells = defaultdict(list) 156 | self.cell_to_room = {} 157 | for y in range(len(self.LAYOUT["rooms"])): 158 | for x in range(len(self.LAYOUT["rooms"][y])): 159 | if self.LAYOUT["rooms"][y][x] != "." 
and self.LAYOUT["rooms"][y][x] != "W": 160 | room_code = self.LAYOUT["rooms"][y][x] 161 | self.room_to_cells[room_code].append((x, y)) 162 | self.cell_to_room[(x, y)] = room_code 163 | 164 | def _load_textures(self): 165 | textures = {} 166 | for fname in Path(__file__).parent.glob("assets/**/*.png"): 167 | name = fname.stem 168 | textures[name] = np.asarray(Image.open(fname)) 169 | self.textures = textures 170 | 171 | def _parse_valid_poss(self): 172 | valid_poss = defaultdict(list) 173 | for k, grid in self.LAYOUT["valid_poss"].items(): 174 | for y, line in enumerate(grid): 175 | for x, c in enumerate(line): 176 | if c == "x": 177 | valid_poss[k].append((x, y)) 178 | self.valid_poss = valid_poss 179 | 180 | def populate(self, grid): 181 | """Fill the grid with the floor and fixed object layout.""" 182 | for y in range(len(self.LAYOUT["rooms"])): 183 | for x in range(len(self.LAYOUT["rooms"][y])): 184 | if self.LAYOUT["rooms"][y][x] == ".": 185 | continue 186 | elif self.LAYOUT["rooms"][y][x] == 'W': 187 | grid.set(x, y, Wall()) 188 | else: 189 | room_code = self.LAYOUT["rooms"][y][x] 190 | assert room_code in ROOMS.keys(),\ 191 | "Invalid room code: {}".format(room_code) 192 | floor = ROOMS[room_code].texture 193 | name = floor 194 | agent_can_overlap = True 195 | can_overlap = True 196 | texture = self.textures[floor].copy() 197 | if self.LAYOUT["fixed_objects"][y][x].islower(): 198 | item = OBJECTS[ord(self.LAYOUT["fixed_objects"][y][x]) - 97] 199 | name += f"_{item}" 200 | agent_can_overlap = False 201 | can_overlap = (item in SURFACES) 202 | texture = draw_obj(texture, self.textures[item].copy()) 203 | grid.set_floor(x, y, 204 | FloorWithObject(name, texture, 205 | agent_can_overlap=agent_can_overlap, 206 | can_overlap=can_overlap)) 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](https://github.com/jlin816/homegrid/raw/main/banner.png) 2 | 3 |

4 | A minimal home grid world environment to evaluate language understanding in interactive agents.
5 | 
6 | 
7 | # 🏠 Getting Started
8 | 
9 | Play as a human:
10 | ```bash
11 | pip install homegrid
12 | ./homegrid/manual_control.py
13 | ```
14 | 
15 | Use as a gym environment:
16 | ```python
17 | import gym
18 | import homegrid
19 | env = gym.make("homegrid-task")
20 | ```
21 | See `homegrid/__init__.py` for the environment configurations used in the paper [Learning to Model the World with Language](https://dynalang.github.io/).
22 | 
23 | # 📑 Documentation
24 | 
25 | HomeGrid tests whether agents can learn to use language that provides information about the world. In addition to task instructions, the env provides scripted _language hints_, simulating knowledge that agents might learn from humans (e.g., in a collaborative setting) or read in text (e.g., on Wikipedia). Agents navigate around a house to find objects and interact with them to perform tasks, while learning to understand language from experience.
26 | 
27 | ### ⚡️ Quick Info
28 | - pixel observations (3x3 partial view of the house)
29 | - both one-hot and token embedding observations available
30 | - discrete action space (movement + object interaction)
31 | - 3 rooms, 7 objects (3 trash bins, 4 trash objects)
32 | - multitask with language instructions + hints
33 | - randomized object placement and object dynamics
34 | 
35 | **Task Templates (38 total tasks):**
36 | - find the `object/bin`: the agent receives a reward of 1 if it is facing the correct object/bin
37 | - get the `object`: the agent receives a reward of 1 if it has the correct object in its inventory
38 | - put the `object` in the `bin`: the agent receives a reward of 1 if the bin contains the object
39 | - move the `object` to the `room`: the agent receives a reward of 1 if the object is in the room
40 | - open the `bin`: the agent receives a reward of 1 if the bin is in the open state
41 | 
42 | **Language Types and Templates**
43 | 
44 | - **Future Observations**: descriptions of what agents might observe in the future, such as "The plates are in the kitchen."
45 |     - _"`object/bin` is in the `room`"_: the object or bin is in the indicated room
46 |     - _"i moved the `object` to the `room`"_: the object has been moved to the room
47 |     - _"there will be `object` in the `room`"_: the object will spawn in the room in five timesteps
48 | - **Dynamics**: descriptions of environment dynamics, such as "Pedal to open the compost bin."
49 |     - _"`action` to open the `bin`"_: the indicated action is the correct action to open the bin
50 | - **Corrections**: interactive, task-specific feedback based on what the agent is currently doing, such as "Turn around."
51 |     - _"no, turn around"_: the agent's distance to the current goal object or bin (given the task) has increased compared to the last timestep
52 | 
53 | Environment instances are provided for the task instruction plus each of the language types above in `homegrid/__init__.py`.
54 | 
55 | Language is provided by `homegrid/language_wrappers.py` and streamed one token per timestep by default. Both token IDs and token embeddings are provided in the observation, using the [T5](https://arxiv.org/abs/1910.10683) tokenizer and encoder model. The original paper introducing HomeGrid uses the token IDs.
56 | Some strings are higher priority than others and may interrupt a string that is currently being read. By default, the environment streams hints that apply to the whole episode during the first timesteps, while the agent does not move. See `homegrid/language_wrappers.py` for details.
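To make the streaming concrete, here is a minimal sketch of a rollout (the random-action policy and step count are illustrative only; the token decoding mirrors `homegrid/manual_control.py`):

```python
import gym
import homegrid  # registers the homegrid-* environments
from tokenizers import Tokenizer

tok = Tokenizer.from_pretrained("t5-small")
env = gym.make("homegrid-task", disable_env_checker=True)
obs, info = env.reset()

for _ in range(100):
    # Random actions, purely to drive the episode forward
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # One T5 token (id and embedding) is streamed per timestep
    print(tok.decode([obs["token"]]), obs["token_embed"].shape, obs["is_read_step"])
    if terminated or truncated:
        obs, info = env.reset()
```

With hint types that use the preread phase (`homegrid-future`, `homegrid-dynamics`), `is_read_step` is `True` for the initial timesteps while whole-episode hints are streamed.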
57 | 
58 | ### Observation Space
59 | 
60 | For the full HomeGrid environment with language:
61 | 
62 | - `image (uint8 (96, 96, 3))`: pixel agent-centric local view
63 | - `token (int)`: T5 token ID streamed at the current timestep
64 | - `token_embed (float32 (512,))`: T5 embedding of the token at the current timestep
65 | - `is_read_step (bool)`: for logging, `True` if the agent is reading strings before the episode begins
66 | - `log_language_info (str)`: for logging, human-readable text for the string currently being streamed
67 | 
68 | # 💻 Development
69 | 
70 | New development and extensions to the environment are welcome!
71 | 
72 | ### Adding new language utterances
73 | 
74 | Sentences are pre-embedded and cached in a file for training efficiency. You'll have to append the additional sentences to `homegrid/homegrid_sentences.txt` and re-generate the cached token and embedding file with the following command:
75 | ```bash
76 | python scripts/embed_offline.py \
77 |     --infile homegrid/homegrid_sentences.txt \
78 |     --outfile homegrid/homegrid_embeds.pkl \
79 |     --model t5
80 | ```
81 | 
82 | ### Adding new layouts and objects
83 | 
84 | HomeGrid currently has one layout and a fixed set of objects that are sampled to populate each episode. Many of the receptacles and containers (e.g. cabinets) are disabled for simplicity.
85 | 
86 | To add new layouts, create a new class in `homegrid/layout.py`, following the pattern of `ThreeRoom` (a sketch appears below).
87 | 
88 | To add new static (non-interactive) objects, add assets to `homegrid/assets/` and then specify where they are rendered in `homegrid/layout.py`.
89 | 
90 | To add new interactive objects, additionally specify how they behave in `homegrid/homegrid_base.py:step`.
91 | 
92 | # Acknowledgments
93 | 
94 | HomeGrid is based on [MiniGrid](https://github.com/Farama-Foundation/Minigrid).
95 | The environment assets are thanks to [limezu](https://limezu.itch.io/) and [Mounir Tohami](https://mounirtohami.itch.io/).
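Below is the layout sketch referenced under "Adding new layouts and objects": a hedged, hypothetical `TwoRoom` layout. The class name, grids, and object placements are illustrative and not part of the package, and hooking the new class into the environment happens in `homegrid/homegrid_base.py`, which is not shown here.

```python
from homegrid.layout import ThreeRoom

class TwoRoom(ThreeRoom):
    """Hypothetical layout: kitchen and living room side by side.

    Subclassing ThreeRoom reuses its texture loading, grid parsing, and
    populate() logic, so only LAYOUT changes. Every row must have the
    same width; 'W' marks walls, room letters key into ROOMS, lowercase
    letters index into OBJECTS (a=cupboard, b=stove, ...), and 'x' marks
    cells where the agent, trash objects, or bins may be placed.
    Keep grid rows flush left so (x, y) coordinates line up across grids.
    """
    LAYOUT = {
        "rooms":
            """WWWWWWWWWW
WKKKKLLLLW
WKKKKLLLLW
WKKKKLLLLW
WWWWWWWWWW
""".splitlines(),
        "fixed_objects":
            """WWWWWWWWWW
WaabcLhhmW
WKKKKLLLLW
WddddLjkLW
WWWWWWWWWW
""".splitlines(),
        "valid_poss": {
            "agent_start":
                """..........
..........
.xxx..xx..
..........
..........
""".splitlines(),
            "obj":
                """..........
..........
....x.....
........x.
..........
""".splitlines(),
            "can":
                """..........
.....x....
........x.
..........
..........
""".splitlines(),
        },
    }
```

Because `ThreeRoom.__init__` reads `self.LAYOUT`, overriding the class attribute is enough to get a fully parsed layout; only make sure the `valid_poss` cells fall on walkable floor tiles.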
96 | 97 | # Citation 98 | 99 | ``` 100 | @article{lin2023learning, 101 | title={Learning to Model the World with Language}, 102 | author={Jessy Lin and Yuqing Du and Olivia Watkins and Danijar Hafner and Pieter Abbeel and Dan Klein and Anca Dragan}, 103 | year={2023}, 104 | eprint={2308.01399}, 105 | archivePrefix={arXiv}, 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /homegrid/language_wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | from enum import Enum 3 | import os 4 | import random 5 | from typing import Dict, List 6 | import pathlib 7 | import pickle 8 | 9 | import gym 10 | from gym import spaces 11 | import numpy as np 12 | from tokenizers import Tokenizer 13 | 14 | from homegrid.base import Pickable, Storage 15 | from homegrid.layout import room2name 16 | 17 | class MultitaskWrapper(gym.Wrapper): 18 | """Continually sample tasks during an episode, rewarding the agent for 19 | completion.""" 20 | 21 | Tasks = Enum("Tasks", [ 22 | "find", "get", "cleanup", "rearrange", "open"], 23 | start=0) 24 | 25 | def __init__(self, env): 26 | super().__init__(env) 27 | self.tasks = list(MultitaskWrapper.Tasks) 28 | 29 | def sample_task(self): 30 | task_type = random.choice(self.tasks) 31 | 32 | if task_type == MultitaskWrapper.Tasks.find: 33 | obj_name = random.choice(self.env.objs).name 34 | task = f"find the {obj_name}" 35 | def reward_fn(symbolic_state): 36 | return int(symbolic_state["front_obj"] == obj_name) 37 | elif task_type == MultitaskWrapper.Tasks.get: 38 | obj_name = random.choice([ob for ob in self.env.objs if \ 39 | isinstance(ob, Pickable)]).name 40 | task = f"get the {obj_name}" 41 | def reward_fn(symbolic_state): 42 | return int(symbolic_state["agent"]["carrying"] == obj_name) 43 | elif task_type == MultitaskWrapper.Tasks.open: 44 | obj_name = random.choice([ob for ob in self.env.objs if \ 45 | isinstance(ob, Storage)]).name 46 | task = f"open the {obj_name}" 47 | def reward_fn(symbolic_state): 48 | for obj in symbolic_state["objects"]: 49 | if obj["name"] == obj_name: 50 | return int(obj["state"] == "open") 51 | elif task_type == MultitaskWrapper.Tasks.cleanup: 52 | obj_name = random.choice([ob for ob in self.env.objs if \ 53 | isinstance(ob, Pickable)]).name 54 | bin_name = random.choice([ob for ob in self.env.objs if \ 55 | isinstance(ob, Storage)]).name 56 | task = f"put the {obj_name} in the {bin_name}" 57 | def reward_fn(symbolic_state): 58 | if symbolic_state["agent"]["carrying"] == obj_name: 59 | return 0.5 60 | for obj in symbolic_state["objects"]: 61 | if obj["name"] == bin_name: 62 | return int(obj_name in obj["contains"]) 63 | elif task_type == MultitaskWrapper.Tasks.rearrange: 64 | room_code = random.choice(list(self.env.room_to_cells.keys())) 65 | obj_name = random.choice([ob for ob in self.env.objs if \ 66 | isinstance(ob, Pickable)]).name 67 | task = f"move the {obj_name} to the {room2name[room_code]}" 68 | def reward_fn(symbolic_state): 69 | if symbolic_state["agent"]["carrying"] == obj_name: 70 | return 0.5 71 | for obj in symbolic_state["objects"]: 72 | if obj["name"] == obj_name: 73 | return int(obj["room"] == room_code) 74 | else: 75 | raise ValueError(f"Unknown task type {task_type}") 76 | def dist_goal(symbolic_state): 77 | goal_name = obj_name 78 | if task_type == MultitaskWrapper.Tasks.cleanup: 79 | goal_name = bin_name if symbolic_state["agent"]["carrying"] == obj_name \ 80 | else obj_name 81 | pos 
= [o for o in symbolic_state["objects"] if \ 82 | o["name"] == goal_name][0]["pos"] 83 | return abs(self.agent_pos[0] - pos[0]) + abs(self.agent_pos[1] - pos[1]) 84 | 85 | self.task = task 86 | self.reward_fn = reward_fn 87 | self.dist_goal = dist_goal 88 | self.subtask_done = False 89 | self.start_step = self.step_cnt 90 | 91 | def reset(self): 92 | obs, info = self.env.reset() 93 | self.step_cnt = 0 94 | self.start_step = 0 95 | self.accomplished_tasks = [] 96 | self.task_times = [] 97 | self.sample_task() 98 | info.update({ 99 | "log_timesteps_with_task": self.step_cnt - self.start_step, 100 | "log_new_task": True, 101 | "log_dist_goal": self.dist_goal(info["symbolic_state"]) 102 | }) 103 | return obs, info 104 | 105 | def step(self, action): 106 | self.step_cnt += 1 107 | obs, rew, term, trunc, info = self.env.step(action) 108 | info.update({ 109 | "log_timesteps_with_task": self.step_cnt - self.start_step, 110 | "log_new_task": False, 111 | "log_dist_goal": self.dist_goal(info["symbolic_state"]) 112 | }) 113 | if term: 114 | return obs, rew, term, trunc, info 115 | rew = self.reward_fn(info["symbolic_state"]) 116 | if rew == 1: 117 | self.accomplished_tasks.append(self.task) 118 | self.task_times.append(self.step_cnt - self.start_step) 119 | self.sample_task() 120 | info.update({ 121 | "log_timesteps_with_task": self.step_cnt - self.start_step, 122 | "log_accomplished_tasks": self.accomplished_tasks, 123 | "log_task_times": self.task_times, 124 | "log_new_task": True, 125 | "log_dist_goal": self.dist_goal(info["symbolic_state"]) 126 | }) 127 | elif rew == 0.5: 128 | if self.subtask_done: rew = 0 # don't reward twice 129 | self.subtask_done = True 130 | return obs, rew, term, trunc, info 131 | 132 | class LanguageWrapper(gym.Wrapper): 133 | """Provide the agent with language information one token at a time, using underlying 134 | environment state and task wrapper. 
135 | 136 | Configures types of language available, and specifies logic for which language is provided at 137 | a given step, if multiple strings are available.""" 138 | 139 | def __init__(self, 140 | env, 141 | # Max # tokens during prereading phase (for future/dynamics) 142 | preread_max=-1, 143 | # How often to repeat the task description 144 | repeat_task_every=20, 145 | # Prob of sampling descriptions when we don't have task language 146 | p_language=0.2, 147 | debug=False, 148 | lang_types=["task", "future", "dynamics", "corrections", "termination"], 149 | ): 150 | super().__init__(env) 151 | assert len(lang_types) >= 1 and "task" in lang_types, \ 152 | f"Must have task language, {lang_types}" 153 | for t in lang_types: 154 | assert t in ["task", "future", "dynamics", "corrections", "termination"], \ 155 | f"Unknown language type {t}" 156 | 157 | if "dynamics" in lang_types or "future" in lang_types: 158 | assert preread_max > -1, \ 159 | "Must have preread for dynamics/future language" 160 | 161 | self.instruction_only = len(lang_types) == 1 and lang_types[0] == "task" 162 | self.preread_max = preread_max 163 | self.repeat_task_every = repeat_task_every 164 | self.p_language = p_language 165 | self.debug = debug 166 | self.lang_types = lang_types 167 | self.preread = -1 if self.instruction_only else self.preread_max 168 | 169 | directory = pathlib.Path(__file__).resolve().parent 170 | with open(directory / "homegrid_embeds.pkl", "rb") as f: 171 | self.cache, self.embed_cache = pickle.load(f) 172 | self.empty_token = self.cache[""] 173 | # List of tokens of current utterance we're streaming 174 | self.tokens = [self.empty_token] 175 | # Index of self.tokens for current timestep 176 | self.cur_token = 0 177 | self.embed_size = 512 178 | self.observation_space = spaces.Dict({ 179 | **self.env.observation_space.spaces, 180 | "token": spaces.Box( 181 | 0, 32100, 182 | shape=(), 183 | dtype=np.uint32), 184 | "token_embed": spaces.Box( 185 | -np.inf, np.inf, 186 | shape=(self.embed_size,), 187 | dtype=np.float32), 188 | "is_read_step": spaces.Box( 189 | low=np.array(False), 190 | high=np.array(True), 191 | shape=(), 192 | dtype=bool), 193 | "log_language_info": spaces.Text( 194 | max_length=10000, 195 | ), 196 | }) 197 | if self.debug: 198 | self.tok = Tokenizer.from_pretrained("t5-small") 199 | 200 | def get_descriptions(self, state): 201 | # facts: 202 | # - object locations (beginning only but also anytime) 203 | # - irreversible state (don't change) 204 | # - dynamics (don't change) 205 | descs = [] 206 | for obj in state["objects"]: 207 | if "dynamics" in self.lang_types and obj["action"]: 208 | descs.append(f"{obj['action']} to open the {obj['name']}") 209 | if "future" in self.lang_types and obj["room"]: 210 | descs.append(f"{obj['name']} is in the {room2name[obj['room']]}") 211 | return descs 212 | 213 | def _tokenize(self, string): 214 | if string in self.cache: 215 | return self.cache[string] 216 | if self.debug: 217 | return self.tok(string, add_special_tokens=False)["input_ids"] 218 | raise NotImplementedError(f"tokenize, string not preembedded: >{string}<") 219 | 220 | def _embed(self, string): 221 | if string in self.embed_cache: 222 | return self.embed_cache[string] 223 | if self.debug: 224 | return [5555] * len(self.tokens) 225 | raise NotImplementedError(f"embed, string not preembedded: >{string}<") 226 | 227 | def _set_current_string(self, string_or_strings): 228 | if isinstance(string_or_strings, list): 229 | self.string = " ".join(string_or_strings) 230 | self.tokens = 
[x for string in string_or_strings \ 231 | for x in self._tokenize(string)] 232 | self.token_embeds = [x for string in string_or_strings \ 233 | for x in self._embed(string)] 234 | self.cur_token = 0 235 | elif isinstance(string_or_strings, str): 236 | string = string_or_strings 237 | self.string = string 238 | self.tokens = self._tokenize(string) 239 | self.token_embeds = self._embed(string) 240 | self.cur_token = 0 241 | 242 | def _increment_token(self): 243 | if self._lang_is_empty(): 244 | return 245 | self.cur_token += 1 246 | if self.cur_token == len(self.tokens): 247 | self.string = "" 248 | self.tokens = [self.empty_token] 249 | self.token_embeds = [self._embed(self.string)] 250 | self.cur_token = 0 251 | 252 | def _lang_is_empty(self): 253 | return self.string == "" 254 | 255 | def add_language_to_obs(self, obs, info): 256 | """Adds language keys to the observation: 257 | - token (int): current token 258 | - token_embed (np.array): embedding of current token 259 | - log_language_info (str): human-readable info about language 260 | 261 | On each step, either 262 | describe new task (will interrupt other language) 263 | continue tokens that are currently being streamed 264 | repeat task if it's time 265 | describe something that changed or will happen (events) 266 | describe a fact (if not preread) - TODO 267 | correct the agent - TODO 268 | """ 269 | if self._step_cnt >= self.preread and info["log_new_task"]: 270 | # on t=self._step_cnt, we will start streaming the new task 271 | self._set_current_string(self.env.task) 272 | self._last_task_repeat = self._step_cnt 273 | 274 | if self._lang_is_empty(): 275 | describable_evts = [e for e in info["events"] 276 | if e.get("type", "none") in self.lang_types] 277 | if self.repeat_task_every > 0 and \ 278 | self._step_cnt - self._last_task_repeat >= self.repeat_task_every: 279 | self._set_current_string(self.env.task) 280 | self._last_task_repeat = self._step_cnt 281 | elif len(describable_evts) > 0: 282 | evt = random.choice(describable_evts) 283 | self._set_current_string(evt["description"]) 284 | elif np.random.rand() < self.p_language: 285 | if "corrections" in self.lang_types and \ 286 | info["log_dist_goal"] > self.last_dist: 287 | self._set_current_string("no, turn around") 288 | else: 289 | descs = self.get_descriptions(info["symbolic_state"]) 290 | if len(descs) > 0: 291 | self._set_current_string(random.choice(descs)) 292 | 293 | obs.update({ 294 | "token": self.tokens[self.cur_token], 295 | "token_embed": self.token_embeds[self.cur_token], 296 | "log_language_info": self.string, 297 | }) 298 | self._increment_token() 299 | return obs 300 | 301 | def reset(self): 302 | obs, info = self.env.reset() 303 | obs["is_read_step"] = False 304 | self.last_dist = info["log_dist_goal"] 305 | if self.preread_max > -1: 306 | descs = self.get_descriptions(info["symbolic_state"]) 307 | random.shuffle(descs) 308 | self._set_current_string(descs) 309 | self.preread = min(len(self.tokens), self.preread_max) 310 | obs["image"] = obs["image"] // 2 311 | obs["is_read_step"] = True 312 | self.init_obs = obs 313 | self.init_info = info 314 | self._step_cnt = 0 315 | self._last_task_repeat = 0 316 | obs = self.add_language_to_obs(obs, info) 317 | return obs, info 318 | 319 | def step(self, action): 320 | self._step_cnt += 1 321 | if self._step_cnt <= self.preread: 322 | obs, rew, term, trunc, info = self.init_obs, 0, False, False, self.init_info 323 | obs["is_read_step"] = True 324 | else: 325 | obs, rew, term, trunc, info = self.env.step(action) 326 | 
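Each step thus exposes exactly one token of the current utterance. A sketch of what a consumer sees (hedged: `make_env` is a stand-in for however the base environment plus `MultitaskWrapper` are constructed; `log_language_info` already carries the full current utterance, so this loop is purely for inspection):

```
env = LanguageWrapper(make_env(), lang_types=["task"])
obs, info = env.reset()
stream = [(obs["token"], obs["log_language_info"])]
for _ in range(30):
    obs, rew, term, trunc, info = env.step(env.action_space.sample())
    stream.append((obs["token"], obs["log_language_info"]))
# Consecutive steps sharing the same log_language_info string are tokens of
# the same utterance; an empty string marks steps with no language.
```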
obs["is_read_step"] = False 327 | obs = self.add_language_to_obs(obs, info) 328 | self.last_dist = info["log_dist_goal"] 329 | return obs, rew, term, trunc, info 330 | -------------------------------------------------------------------------------- /homegrid/homegrid_base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List, Tuple 2 | import gym 3 | from gym import spaces 4 | from collections import defaultdict 5 | from enum import IntEnum 6 | import random 7 | import numpy as np 8 | 9 | from PIL import Image 10 | from PIL import ImageFont 11 | from PIL import ImageDraw 12 | 13 | from homegrid.base import ( 14 | MiniGridEnv, Grid, 15 | Storage, Inanimate, Pickable 16 | ) 17 | from homegrid.layout import ThreeRoom, CANS, TRASH, room2name 18 | 19 | 20 | class HomeGridBase(MiniGridEnv): 21 | 22 | class Actions(IntEnum): 23 | left = 0 24 | right = 1 25 | up = 2 26 | down = 3 27 | # item actions 28 | pickup = 4 29 | drop = 5 30 | # storage actions 31 | get = 6 32 | pedal = 7 33 | grasp = 8 34 | lift = 9 35 | 36 | ac2dir = { 37 | Actions.right: 0, 38 | Actions.down: 1, 39 | Actions.left: 2, 40 | Actions.up: 3, 41 | } 42 | 43 | def __init__(self, 44 | layout=ThreeRoom, 45 | num_trashcans=2, 46 | num_trashobjs=2, 47 | view_size=3, 48 | max_steps=100, 49 | p_teleport=0.05, 50 | max_objects=4, 51 | p_unsafe=0.0, 52 | fixed_state=None, 53 | ): 54 | self.layout = layout() 55 | self.textures = self.layout.textures 56 | super().__init__( 57 | width=self.layout.width, 58 | height=self.layout.height, 59 | render_mode="rgb_array", 60 | agent_view_size=view_size, 61 | max_steps=max_steps, 62 | ) 63 | self.actions = HomeGridBase.Actions 64 | self.action_space = spaces.Discrete(len(self.actions)) 65 | self.num_trashcans = num_trashcans 66 | self.num_trashobjs = num_trashobjs 67 | self.p_teleport = p_teleport 68 | self.max_objects = max_objects 69 | self.p_unsafe = p_unsafe 70 | self.fixed_state = fixed_state 71 | 72 | @property 73 | def step_cnt(self): 74 | return self._step_cnt 75 | 76 | def init_from_state(self, state): 77 | """Initialize the env from a symbolic state.""" 78 | self._create_layout(self.width, self.height) 79 | # place agent 80 | self.agent_pos = state["agent"]["pos"] 81 | self.agent_dir = state["agent"]["dir"] 82 | self.objs = [] 83 | # place objects with appropriate state 84 | for ob in state["objects"]: 85 | if ob["pos"] == (-1, -1): 86 | print(f"Skipping carried object {ob['name']}") 87 | continue 88 | if ob["type"] == "Storage": 89 | pfx = ob["name"].replace(" ", "_") 90 | obj = Storage( 91 | name=ob["name"], 92 | textures={ 93 | "open": self.textures[f"{pfx}_open"], 94 | "closed": self.textures[f"{pfx}_closed"]}, 95 | state=ob["state"], 96 | action=ob["action"], 97 | contains=ob["contains"], 98 | ) 99 | elif ob["type"] == "Pickable": 100 | obj = Pickable( 101 | name=ob["name"], 102 | texture=self.textures[ob["name"]], 103 | invisible=ob["invisible"], 104 | ) 105 | else: 106 | raise NotImplementedError("Obj type {ob['type']}") 107 | self.objs.append(obj) 108 | self.place_obj(obj, top=ob["pos"], 109 | size=(1,1), max_tries=1) 110 | 111 | def _add_cans_to_house(self): 112 | cans = random.sample(CANS, self.num_trashcans) 113 | poss = random.sample(self.layout.valid_poss["can"], self.num_trashcans) 114 | can_objs = [] 115 | for i, can in enumerate(cans): 116 | obj = Storage(can, { 117 | "open": self.textures[f"{can}_open"], 118 | "closed": self.textures[f"{can}_closed"]}, 119 | # Make one of the cans irreversibly 
broken 120 | reset_broken_after=200 if i == 0 else 5) 121 | pos = self.place_obj(obj, top=poss[i], size=(1,1), max_tries=5) 122 | can_objs.append(obj) 123 | self.objs.append(obj) 124 | 125 | def _add_objs_to_house(self): 126 | trash_objs = random.sample(TRASH, self.num_trashobjs) 127 | poss = random.sample(self.layout.valid_poss["obj"], self.num_trashobjs) 128 | trashobj_objs = [] 129 | for i, trash in enumerate(trash_objs): 130 | obj = Pickable(trash, self.textures[trash]) 131 | pos = self.place_obj(obj, top=poss[i], size=(1,1), max_tries=5) 132 | trashobj_objs.append(obj) 133 | self.objs.append(obj) 134 | 135 | def _gen_grid(self, width, height): 136 | if self.fixed_state: 137 | print("Initializing from fixed state") 138 | self.init_from_state(self.fixed_state) 139 | return 140 | regenerate = True 141 | while regenerate: 142 | self._create_layout(width, height) 143 | regenerate = False 144 | 145 | self.objs = [] 146 | self.goal = {"obj": None, "can": None} 147 | # Place objects 148 | self._add_cans_to_house() 149 | self._add_objs_to_house() 150 | 151 | # Place agent 152 | agent_poss = random.choice(self.layout.valid_poss["agent_start"]) 153 | self.agent_pos = self.place_agent(top=agent_poss, size=(1, 1)) 154 | 155 | def _create_layout(self, width, height): 156 | # Create grid with surrounding walls 157 | self.grid = Grid(width, height) 158 | self.layout.populate(self.grid) 159 | self.room_to_cells = self.layout.room_to_cells 160 | self.cell_to_room = self.layout.cell_to_room 161 | 162 | def _maybe_teleport(self): 163 | if np.random.random() > self.p_teleport: 164 | return False 165 | objs = [o for o in self.objs if isinstance(o, Pickable) and o.cur_pos[0] != 166 | -1 and not o.invisible] 167 | if len(objs) == 0: 168 | print(self.objs) 169 | print([o.cur_pos for o in self.objs]) 170 | print(self.all_events) 171 | return False 172 | obj = random.choice(objs) 173 | # Choose a random new location with no object to place this 174 | poss = random.choice([ 175 | pos for pos in self.layout.valid_poss["obj"] \ 176 | if pos != obj.cur_pos and self.grid.get(*pos) is None \ 177 | and pos != self.agent_pos]) 178 | self.grid.set(*obj.cur_pos, None) 179 | obj.cur_pos = poss 180 | self.grid.set(*poss, obj) 181 | return obj 182 | 183 | def _maybe_spawn(self): 184 | new_objs = [t for t in TRASH if t not in \ 185 | [o.name for o in self.objs]] 186 | if np.random.rand() < 0.1 * len(new_objs): 187 | trash = random.choice(new_objs) 188 | obj = Pickable(trash, self.textures[trash], invisible=True) 189 | poss = random.choice([ 190 | pos for pos in self.layout.valid_poss["obj"] \ 191 | if pos != obj.cur_pos and self.grid.get(*pos) is None \ 192 | and pos != self.agent_pos]) 193 | self.place_obj(obj, top=poss, size=(1,1), max_tries=5) 194 | self.objs.append(obj) 195 | return obj 196 | return None 197 | 198 | def _maybe_unsafe(self): 199 | if len(self.unsafe_poss) == 0 and np.random.rand() < self.p_unsafe: 200 | can = random.choice([o for o in self.objs if isinstance(o, Storage)]) 201 | self.unsafe_poss = set() 202 | self.unsafe_name = can.name 203 | for x in [can.cur_pos[0] - 1, can.cur_pos[0], can.cur_pos[0] + 1]: 204 | for y in [can.cur_pos[1] - 1, can.cur_pos[1], can.cur_pos[1] + 1]: 205 | self.unsafe_poss.add((x,y)) 206 | self.unsafe_end = self.step_count + random.randint(1, 10) 207 | return can 208 | return None 209 | 210 | def reset(self, *, seed=None, options=None): 211 | self.prev_action = "Reset" 212 | obs, info = super().reset(seed=seed, options=options) 213 | # All events in the episode so far 214 | 
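The `_maybe_teleport`, `_maybe_spawn`, and `_maybe_unsafe` hooks above inject the stochastic events that `step` reports through `info["events"]`: with probability `p_teleport` a visible object is moved, missing trash objects respawn (invisibly at first) with probability proportional to how many are absent, and spills appear when `p_unsafe > 0`. A sketch of watching those events during a random rollout (illustrative only; `env` is any `HomeGridBase` instance):

```
obs, info = env.reset()
for _ in range(100):
    obs, rew, term, trunc, info = env.step(env.action_space.sample())
    for evt in info["events"]:
        # "future" events come from teleports/spawns, "termination" from spills
        print(env.step_count, evt["type"], "-", evt["description"])
    if term or trunc:
        break
```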
self.all_events = [] 215 | self.step_count = 0 216 | self.unsafe_poss = {} 217 | self.unsafe_end = -1 218 | self.unsafe_name = None 219 | info = { 220 | "symbolic_state": self.get_full_symbolic_state(), 221 | "events": [] 222 | } 223 | return obs, info 224 | 225 | def step(self, action): 226 | self.step_count += 1 227 | 228 | reward = 0 229 | terminated = False 230 | truncated = False 231 | success = None 232 | events = [] 233 | 234 | # Step all the object states 235 | for obj in self.objs: 236 | obj.tick() 237 | 238 | # Get the position in front of the agent 239 | fwd_pos = self.front_pos 240 | 241 | # Get the contents of the cell in front of the agent 242 | fwd_cell = self.grid.get(*fwd_pos) 243 | fwd_floor = self.grid.get_floor(*fwd_pos) 244 | 245 | if action == self.actions.left or \ 246 | action == self.actions.right or \ 247 | action == self.actions.up or \ 248 | action == self.actions.down: 249 | self.agent_dir = HomeGridBase.ac2dir[action] 250 | # Get the position in front of the agent after turning 251 | fwd_pos = self.front_pos 252 | 253 | # Get the contents of the cell in front of the agent 254 | fwd_cell = self.grid.get(*fwd_pos) 255 | fwd_floor = self.grid.get_floor(*fwd_pos) 256 | 257 | if (fwd_cell is None or fwd_cell.agent_can_overlap()) and \ 258 | (fwd_floor is None or fwd_floor.agent_can_overlap()): 259 | self.agent_pos = tuple(fwd_pos) 260 | 261 | # Pick up an object 262 | elif action == self.actions.pickup: 263 | if isinstance(fwd_cell, Pickable) and fwd_cell.can_pickup(): 264 | if self.carrying is None: 265 | self.carrying = fwd_cell 266 | self.carrying.cur_pos = (-1, -1) 267 | self.grid.set(fwd_pos[0], fwd_pos[1], None) 268 | 269 | # Drop an object 270 | elif action == self.actions.drop: 271 | if not fwd_cell and self.carrying: 272 | self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying) 273 | self.carrying.cur_pos = tuple(fwd_pos) 274 | self.carrying = None 275 | elif isinstance(fwd_cell, Storage) and self.carrying: 276 | succeeded = fwd_cell.interact(self.actions(action).name, 277 | obj=self.carrying) 278 | if succeeded: 279 | self.carrying = None 280 | 281 | elif action == self.actions.get: 282 | if not self.carrying and isinstance(fwd_cell, Storage): 283 | obj = fwd_cell.interact(self.actions(action).name) 284 | if obj: 285 | self.carrying = obj 286 | self.carrying.cur_pos = (-1, -1) 287 | 288 | elif action == self.actions.pedal or action == self.actions.grasp or \ 289 | action == self.actions.lift: 290 | if isinstance(fwd_cell, Storage): 291 | succeeded = fwd_cell.interact(self.actions(action).name) 292 | 293 | else: 294 | raise ValueError(f"Unknown action: {action}") 295 | 296 | if self.step_count >= self.max_steps: 297 | truncated = True 298 | if (terminated or truncated) and success is None: 299 | success = False 300 | 301 | if self.render_mode == "human": 302 | self.render_with_text() 303 | obs = self.gen_obs() 304 | 305 | # For rendering purposes 306 | self.prev_action = HomeGridBase.Actions(action).name 307 | if success: 308 | self.done_condition = "success" 309 | elif truncated: 310 | self.done_condition = "truncated" 311 | 312 | # Random events 313 | if self.agent_pos in self.unsafe_poss: 314 | terminated = True 315 | reward = -1 316 | 317 | if len(self.unsafe_poss) > 0 and self.step_count == self.unsafe_end: 318 | self.unsafe_poss = {} 319 | self.unsafe_name = None 320 | self.unsafe_end = -1 321 | events.append({ 322 | "type": "termination", 323 | "description": f"i cleaned the spill", 324 | }) 325 | elif len(self.unsafe_poss) == 0: 326 | obj = 
self._maybe_unsafe() 327 | if obj: 328 | events.append({ 329 | "type": "termination", 330 | "description": f"spill near the {self.unsafe_name}", 331 | }) 332 | else: 333 | events.append({ 334 | "type": "termination", 335 | "description": f"spill near the {self.unsafe_name}", 336 | }) 337 | obj = self._maybe_teleport() 338 | if obj: 339 | room_code = self.cell_to_room[obj.cur_pos] 340 | room_name = room2name[room_code] 341 | events.append({ 342 | "type": "future", "obj": obj, 343 | "room": room_code, 344 | "description": f"i moved the {obj.name} to the {room_name}" 345 | }) 346 | obj = self._maybe_spawn() 347 | if obj: 348 | room_code = self.cell_to_room[obj.cur_pos] 349 | room_name = room2name[room_code] 350 | events.append({ 351 | "type": "future", "obj": obj, 352 | "description": f"there will be {obj.name} in the {room_name} later" 353 | }) 354 | 355 | 356 | self.all_events.append(events) 357 | info = { 358 | "success": success, 359 | "action": action, 360 | "symbolic_state": self.get_full_symbolic_state(), 361 | "events": events, 362 | "all_events": self.all_events, 363 | } 364 | 365 | return obs, reward, terminated, truncated, info 366 | 367 | def get_full_symbolic_state(self) -> Dict: 368 | fwd_pos = self.front_pos 369 | fwd_cell = self.grid.get(*fwd_pos) 370 | if isinstance(fwd_cell, Pickable) or isinstance(fwd_cell, Storage): 371 | front_obj = fwd_cell.name 372 | else: 373 | front_obj = None 374 | 375 | state = { 376 | "step": self.step_count, 377 | "agent": { 378 | "pos": self.agent_pos, 379 | "room": self.cell_to_room[self.agent_pos] if self.agent_pos in self.cell_to_room else None, 380 | "dir": self.agent_dir, 381 | "carrying": self.carrying.name if self.carrying else None 382 | }, 383 | "objects": [ 384 | { 385 | "name": obj.name, 386 | "type": obj.__class__.__name__, 387 | "pos": obj.cur_pos, 388 | "room": self.cell_to_room[obj.cur_pos] if (obj.cur_pos[0] != -1 and obj.cur_pos in self.cell_to_room) else None, 389 | "state": getattr(obj, "state", None), 390 | "action": getattr(obj, "action", None), 391 | "invisible": getattr(obj, "invisible", None), 392 | "contains": [contained_obj.name for contained_obj in obj.contains] if isinstance(obj, Storage) else None, 393 | } for obj in self.objs 394 | ], 395 | "front_obj": front_obj, 396 | "unsafe": { 397 | "name": self.unsafe_name, 398 | "poss": self.unsafe_poss, 399 | "end": self.unsafe_end, 400 | } 401 | } 402 | return state 403 | 404 | def render_with_text(self, text=""): 405 | # Render this env's own frame; there is no wrapped self._env here 406 | img = self.get_frame(self.highlight, self.tile_size) 407 | img = Image.fromarray(img) 408 | draw = ImageDraw.Draw(img) 409 | draw.text((0, 0), text, (0, 0, 0)) 410 | draw.text((0, 45), "Action: {}".format(self.prev_action), (0, 0, 0)) 411 | img = np.asarray(img) 412 | return img 413 | -------------------------------------------------------------------------------- /homegrid/base.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import math 3 | from abc import abstractmethod 4 | from enum import IntEnum 5 | from typing import Any, Callable, Optional, Union 6 | import random 7 | import os 8 | 9 | import gym 10 | import numpy as np 11 | from gym import spaces 12 | from gym.utils import seeding 13 | from PIL import Image 14 | 15 | # Size in pixels of a tile in the full-scale human view 16 | from homegrid.rendering import ( 17 | downsample, 18 | fill_coords, 19 | highlight_img, 20 | point_in_circle, 21 | point_in_line, 22 | point_in_rect, 23 | point_in_triangle, 24 | rotate_fn, 25 | draw_obj, 26 | ) 27 | from
homegrid.window import Window 28 | 29 | TILE_PIXELS = 32 30 | SHOW_GRIDLINES = False 31 | # Draw robot instead of default triangle for agent 32 | USE_AGENT_TEXTURE = True 33 | if USE_AGENT_TEXTURE: 34 | AGENT_TEXTURE = np.asarray(Image.open( 35 | f"{os.path.dirname(__file__)}/assets/robot.png" 36 | )) 37 | # Center the agent in the view 38 | CENTERED_VIEW = True 39 | # Map of agent direction indices to vectors 40 | DIR_TO_VEC = [ 41 | # Pointing right (positive X) 42 | np.array((1, 0)), 43 | # Down (positive Y) 44 | np.array((0, 1)), 45 | # Pointing left (negative X) 46 | np.array((-1, 0)), 47 | # Up (negative Y) 48 | np.array((0, -1)), 49 | ] 50 | 51 | 52 | class WorldObj: 53 | """ 54 | Base class for grid world objects 55 | """ 56 | 57 | def __init__(self, name): 58 | self.name = name 59 | self.contains = None 60 | 61 | # Initial position of the object 62 | self.init_pos = None 63 | 64 | # Current position of the object 65 | self.cur_pos = None 66 | 67 | def can_overlap(self): 68 | """Can the agent overlap with this?""" 69 | return False 70 | 71 | def can_pickup(self): 72 | """Can the agent pick this up?""" 73 | return False 74 | 75 | def can_contain(self): 76 | """Can this contain another object?""" 77 | return False 78 | 79 | def see_behind(self): 80 | """Can the agent see behind this object?""" 81 | return True 82 | 83 | def toggle(self, env, pos): 84 | """Method to trigger/toggle an action this object performs""" 85 | return False 86 | 87 | def render(self, r): 88 | """Draw this object with the given renderer""" 89 | raise NotImplementedError 90 | 91 | def tick(self): 92 | pass 93 | 94 | def encode(self): 95 | """Tuple encoding of this object for render caching.""" 96 | return (self.name,) 97 | 98 | 99 | ## Homegridv2 100 | class Storage(WorldObj): 101 | 102 | def __init__(self, name, textures, state=None, action=None, 103 | contains=None, reset_broken_after=20): 104 | super().__init__(name.replace("_", " ")) 105 | self.textures = { 106 | **textures, 107 | "broken": np.rot90(textures["closed"])} 108 | self.contains = contains or [] 109 | # valid states {"open", "closed", "broken"} 110 | self.state = state if state else \ 111 | random.choice(["open", "closed"]) 112 | self.action = action if action else \ 113 | random.choice(["pedal", "grasp", "lift"]) 114 | self.broken_t = 0 115 | self.reset_broken_after = reset_broken_after 116 | 117 | def agent_can_overlap(self): 118 | return False 119 | 120 | def can_overlap(self): 121 | return True 122 | 123 | def render(self, img): 124 | draw_obj(img, self.textures[self.state]) 125 | 126 | def encode(self): 127 | return (self.name, self.state) 128 | 129 | def _get_contents(self): 130 | if len(self.contains) == 0: 131 | return None 132 | return self.contains.pop() 133 | 134 | def interact(self, action, obj=None): 135 | if action == "get": 136 | if self.state != "open": 137 | return False 138 | return self._get_contents() 139 | elif action == "drop": 140 | if self.state != "open": 141 | return False 142 | if len(self.contains) == 0 and obj: 143 | obj.cur_pos = (-1, -1) 144 | self.contains.append(obj) 145 | return True 146 | elif action in {"pedal", "grasp", "lift"}: 147 | if self.state == "closed": 148 | if action != self.action: 149 | self.state = "broken" 150 | self.broken_t = 0 151 | else: 152 | self.state = "open" 153 | return True 154 | else: 155 | raise NotImplementedError(f"Attempting to interact with {action}") 156 | return False 157 | 158 | def tick(self): 159 | if self.state == "broken": 160 | self.broken_t += 1 161 | if self.broken_t == 
self.reset_broken_after: 162 | self.state = "closed" 163 | self.broken_t = 0 164 | 165 | 166 | class Pickable(WorldObj): 167 | 168 | def __init__(self, name, texture, invisible=False): 169 | super().__init__(name.replace("_", " ")) 170 | self.texture = texture 171 | self.invisible = invisible 172 | # If invisible, make visible after N steps 173 | self.invisible_count = 5 174 | 175 | def agent_can_overlap(self): 176 | return self.invisible 177 | 178 | def can_overlap(self): 179 | return True 180 | 181 | def can_pickup(self): 182 | return True 183 | 184 | def render(self, img): 185 | if self.invisible: return 186 | draw_obj(img, self.texture) 187 | 188 | def encode(self): 189 | return (self.name, self.invisible) 190 | 191 | def tick(self): 192 | if self.invisible: 193 | self.invisible_count -= 1 194 | if self.invisible_count == 0: 195 | self.invisible = False 196 | 197 | 198 | class Inanimate(WorldObj): 199 | 200 | def __init__(self, name, texture, can_overlap=False): 201 | super().__init__(name) 202 | self.texture = texture 203 | self._can_overlap = can_overlap 204 | 205 | def agent_can_overlap(self): 206 | return False 207 | 208 | def can_overlap(self): 209 | return self._can_overlap 210 | 211 | def render(self, img): 212 | draw_obj(img, self.texture) 213 | 214 | def encode(self): 215 | return (self.name,) 216 | 217 | 218 | class FloorWithObject(WorldObj): 219 | 220 | def __init__(self, name, texture, agent_can_overlap, 221 | can_overlap): 222 | super().__init__(name) 223 | self.texture = texture 224 | self._agent_can_overlap = agent_can_overlap 225 | self._can_overlap = can_overlap 226 | 227 | def agent_can_overlap(self): 228 | return self._agent_can_overlap 229 | 230 | def can_overlap(self): 231 | return self._can_overlap 232 | 233 | def render(self, img): 234 | draw_obj(img, self.texture) 235 | 236 | def encode(self): 237 | return (self.name,) 238 | 239 | 240 | class Wall(WorldObj): 241 | def __init__(self): 242 | super().__init__("wall") 243 | 244 | def agent_can_overlap(self): 245 | return False 246 | 247 | def see_behind(self): 248 | return False 249 | 250 | def render(self, img): 251 | fill_coords(img, point_in_rect(0, 1, 0, 1), 252 | np.array([100, 100, 100])) 253 | 254 | 255 | class Grid: 256 | """ 257 | Represent a grid and operations on it 258 | """ 259 | 260 | # Static cache of pre-renderer tiles 261 | tile_cache = {} 262 | 263 | def __init__(self, width, height): 264 | assert width >= 3 265 | assert height >= 3 266 | 267 | self.width = width 268 | self.height = height 269 | 270 | self.grid = [None] * width * height 271 | self.floor_grid = [None] * width * height 272 | 273 | # def __contains__(self, key): 274 | # if isinstance(key, WorldObj): 275 | # for e in self.grid: 276 | # if e is key: 277 | # return True 278 | # elif isinstance(key, tuple): 279 | # for e in self.grid: 280 | # if e is None: 281 | # continue 282 | # if (e.color, e.type) == key: 283 | # return True 284 | # if key[0] is None and key[1] == e.type: 285 | # return True 286 | # return False 287 | 288 | def __eq__(self, other): 289 | grid1 = self.encode() 290 | grid2 = other.encode() 291 | return np.array_equal(grid2, grid1) 292 | 293 | def __ne__(self, other): 294 | return not self == other 295 | 296 | def copy(self): 297 | from copy import deepcopy 298 | 299 | return deepcopy(self) 300 | 301 | def set(self, i, j, v): 302 | assert i >= 0 and i < self.width 303 | assert j >= 0 and j < self.height 304 | self.grid[j * self.width + i] = v 305 | 306 | def set_floor(self, i, j, v): 307 | assert i >= 0 and i < 
self.width 308 | assert j >= 0 and j < self.height 309 | self.floor_grid[j * self.width + i] = v 310 | 311 | def get(self, i, j): 312 | assert i >= 0 and i < self.width 313 | assert j >= 0 and j < self.height 314 | return self.grid[j * self.width + i] 315 | 316 | def get_floor(self, i, j): 317 | assert i >= 0 and i < self.width 318 | assert j >= 0 and j < self.height 319 | return self.floor_grid[j * self.width + i] 320 | 321 | def horz_wall(self, x, y, length=None, obj_type=Wall): 322 | if length is None: 323 | length = self.width - x 324 | for i in range(0, length): 325 | self.set(x + i, y, obj_type()) 326 | 327 | def vert_wall(self, x, y, length=None, obj_type=Wall): 328 | if length is None: 329 | length = self.height - y 330 | for j in range(0, length): 331 | self.set(x, y + j, obj_type()) 332 | 333 | def wall_rect(self, x, y, w, h): 334 | self.horz_wall(x, y, w) 335 | self.horz_wall(x, y + h - 1, w) 336 | self.vert_wall(x, y, h) 337 | self.vert_wall(x + w - 1, y, h) 338 | 339 | def rotate_left(self): 340 | """ 341 | Rotate the grid to the left (counter-clockwise) 342 | """ 343 | 344 | grid = Grid(self.height, self.width) 345 | 346 | for i in range(self.width): 347 | for j in range(self.height): 348 | v = self.get(i, j) 349 | floor = self.get_floor(i, j) 350 | grid.set(j, grid.height - 1 - i, v) 351 | grid.set_floor(j, grid.height - 1 - i, floor) 352 | 353 | return grid 354 | 355 | def slice(self, topX, topY, width, height): 356 | """ 357 | Get a subset of the grid 358 | """ 359 | 360 | grid = Grid(width, height) 361 | 362 | for j in range(0, height): 363 | for i in range(0, width): 364 | x = topX + i 365 | y = topY + j 366 | 367 | if x >= 0 and x < self.width and y >= 0 and y < self.height: 368 | v = self.get(x, y) 369 | floor = self.get_floor(x, y) 370 | else: 371 | v = Wall() 372 | floor = None 373 | 374 | grid.set(i, j, v) 375 | grid.set_floor(i, j, floor) 376 | 377 | return grid 378 | 379 | @classmethod 380 | def render_tile( 381 | cls, obj, floor, agent_dir=None, highlight=False, 382 | tile_size=TILE_PIXELS, subdivs=3, bgcolor="white", pov_dir=None): 383 | """ 384 | Render a tile and cache the result 385 | """ 386 | 387 | # Hash map lookup key for the cache 388 | key = (agent_dir, highlight, tile_size) 389 | key = floor.encode() + key if floor else key 390 | key = obj.encode() + key if obj else key 391 | # rotate objects depending on agent perspective 392 | key = (pov_dir,) + key if pov_dir else key 393 | 394 | if key in cls.tile_cache: 395 | return cls.tile_cache[key] 396 | 397 | # Render floor tile 398 | if bgcolor == "black": 399 | img = np.zeros( 400 | shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8 401 | ) 402 | elif bgcolor == "white": 403 | img = np.ones( 404 | shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8 405 | ) * 255 406 | else: 407 | raise ValueError("Unknown color") 408 | 409 | if floor is not None: 410 | floor.render(img) 411 | 412 | # Draw the grid lines (top and left edges) 413 | if SHOW_GRIDLINES: 414 | fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100)) 415 | fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100)) 416 | 417 | if obj is not None: 418 | obj.render(img) 419 | 420 | # Rotate object textures depending on agent perspective 421 | if not CENTERED_VIEW: 422 | if pov_dir is not None: 423 | img = np.rot90(img, k=(pov_dir + 1) % 4) 424 | 425 | # Overlay the agent on top 426 | if agent_dir is not None: 427 | if USE_AGENT_TEXTURE: 428 | draw_obj(img, AGENT_TEXTURE) 429 | # Show direction indicator 
430 | if CENTERED_VIEW: 431 | tri_fn = point_in_triangle( 432 | (0.65, 0.29), 433 | (0.87, 0.50), 434 | (0.65, 0.71), 435 | ) 436 | tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir) 437 | fill_coords(img, tri_fn, (255, 0, 0)) 438 | else: 439 | tri_fn = point_in_triangle( 440 | (0.12, 0.19), 441 | (0.87, 0.50), 442 | (0.12, 0.81), 443 | ) 444 | 445 | # Rotate the agent based on its direction 446 | tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir) 447 | fill_coords(img, tri_fn, (255, 0, 0)) 448 | 449 | # Highlight the cell if needed 450 | if highlight: 451 | highlight_img(img) 452 | 453 | # Downsample the image to perform supersampling/anti-aliasing 454 | img = downsample(img, subdivs) 455 | 456 | # Cache the rendered tile 457 | cls.tile_cache[key] = img 458 | 459 | return img 460 | 461 | def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None, 462 | pov_dir=None): 463 | """ 464 | Render this grid at a given scale 465 | :param r: target renderer object 466 | :param tile_size: tile size in pixels 467 | """ 468 | 469 | if highlight_mask is None: 470 | highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool) 471 | 472 | # Compute the total grid size 473 | width_px = self.width * tile_size 474 | height_px = self.height * tile_size 475 | 476 | img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) 477 | 478 | # Render the grid 479 | for j in range(0, self.height): 480 | for i in range(0, self.width): 481 | cell = self.get(i, j) 482 | floor = self.floor_grid[j * self.width + i] 483 | 484 | agent_here = np.array_equal(agent_pos, (i, j)) 485 | tile_img = Grid.render_tile( 486 | cell, 487 | floor, 488 | agent_dir=agent_dir if agent_here else None, 489 | highlight=highlight_mask[i, j], 490 | tile_size=tile_size, 491 | pov_dir=pov_dir, 492 | subdivs=1, 493 | ) 494 | 495 | ymin = j * tile_size 496 | ymax = (j + 1) * tile_size 497 | xmin = i * tile_size 498 | xmax = (i + 1) * tile_size 499 | img[ymin:ymax, xmin:xmax, :] = tile_img 500 | 501 | return img 502 | 503 | def process_vis(self, agent_pos): 504 | mask = np.zeros(shape=(self.width, self.height), dtype=bool) 505 | 506 | mask[agent_pos[0], agent_pos[1]] = True 507 | 508 | for j in reversed(range(0, self.height)): 509 | for i in range(0, self.width - 1): 510 | if not mask[i, j]: 511 | continue 512 | 513 | cell = self.get(i, j) 514 | if cell and not cell.see_behind(): 515 | continue 516 | 517 | mask[i + 1, j] = True 518 | if j > 0: 519 | mask[i + 1, j - 1] = True 520 | mask[i, j - 1] = True 521 | 522 | for i in reversed(range(1, self.width)): 523 | if not mask[i, j]: 524 | continue 525 | 526 | cell = self.get(i, j) 527 | if cell and not cell.see_behind(): 528 | continue 529 | 530 | mask[i - 1, j] = True 531 | if j > 0: 532 | mask[i - 1, j - 1] = True 533 | mask[i, j - 1] = True 534 | 535 | for j in range(0, self.height): 536 | for i in range(0, self.width): 537 | if not mask[i, j]: 538 | self.set(i, j, None) 539 | self.set_floor(i, j, None) 540 | 541 | return mask 542 | 543 | 544 | class MiniGridEnv(gym.Env): 545 | """ 546 | 2D grid world game environment 547 | """ 548 | 549 | metadata = { 550 | "render_modes": ["human", "rgb_array"], 551 | "render_fps": 10, 552 | } 553 | 554 | # Enumeration of possible actions 555 | class Actions(IntEnum): 556 | # Turn left, turn right, move forward 557 | left = 0 558 | right = 1 559 | forward = 2 560 | # Pick up an object 561 | pickup = 3 562 | # Drop an object 563 | drop = 4 564 | # Toggle/activate an object 565 | toggle = 
5 566 | 567 | # Done completing task 568 | done = 6 569 | 570 | def __init__( 571 | self, 572 | grid_size: int = None, 573 | width: int = None, 574 | height: int = None, 575 | max_steps: int = 100, 576 | see_through_walls: bool = True, 577 | agent_view_size: int = 7, 578 | render_mode: Optional[str] = None, 579 | highlight: bool = True, 580 | tile_size: int = TILE_PIXELS, 581 | agent_pov: bool = False, 582 | ): 583 | # Can't set both grid_size and width/height 584 | if grid_size: 585 | assert width is None and height is None 586 | width = grid_size 587 | height = grid_size 588 | 589 | # Action enumeration for this environment 590 | self.actions = MiniGridEnv.Actions 591 | 592 | # Actions are discrete integer values 593 | self.action_space = spaces.Discrete(len(self.actions)) 594 | 595 | # Number of cells (width and height) in the agent view 596 | assert agent_view_size % 2 == 1 597 | assert agent_view_size >= 3 598 | self.agent_view_size = agent_view_size 599 | 600 | # Observations are dictionaries containing an 601 | # encoding of the grid 602 | image_observation_space = spaces.Box( 603 | low=0, 604 | high=255, 605 | shape=(self.agent_view_size, self.agent_view_size, 3), 606 | dtype="uint8", 607 | ) 608 | self.observation_space = spaces.Dict( 609 | { 610 | "image": image_observation_space, 611 | "direction": spaces.Discrete(4), 612 | } 613 | ) 614 | 615 | # Range of possible rewards 616 | self.reward_range = (0, 1) 617 | 618 | self.window: Window = None 619 | 620 | # Environment configuration 621 | self.width = width 622 | self.height = height 623 | self.max_steps = max_steps 624 | self.see_through_walls = see_through_walls 625 | 626 | # Current position and direction of the agent 627 | self.agent_pos: np.ndarray = None 628 | self.agent_dir: int = None 629 | 630 | # Current grid and mission and carrying 631 | self.grid = Grid(width, height) 632 | self.carrying = None 633 | 634 | # Rendering attributes 635 | self.render_mode = render_mode 636 | self.highlight = highlight 637 | self.tile_size = tile_size 638 | self.agent_pov = agent_pov 639 | 640 | def reset(self, *, seed=None, options=None): 641 | super().reset(seed=seed) 642 | 643 | # Reinitialize episode-specific variables 644 | self.agent_pos = (-1, -1) 645 | self.agent_dir = -1 646 | 647 | # Generate a new random grid at the start of each episode 648 | self._gen_grid(self.width, self.height) 649 | 650 | # These fields should be defined by _gen_grid 651 | assert ( 652 | self.agent_pos >= (0, 0) 653 | if isinstance(self.agent_pos, tuple) 654 | else all(self.agent_pos >= 0) and self.agent_dir >= 0 655 | ) 656 | 657 | # Check that the agent doesn't overlap with an object 658 | start_cell = self.grid.get(*self.agent_pos) 659 | assert start_cell is None or start_cell.can_overlap() 660 | 661 | # Item picked up, being carried, initially nothing 662 | self._init_inventory() 663 | 664 | # Step count since episode start 665 | self.step_count = 0 666 | 667 | if self.render_mode == "human": 668 | self.render() 669 | 670 | # Return first observation 671 | obs = self.gen_obs() 672 | 673 | return obs, {} 674 | 675 | def _init_inventory(self): 676 | self.carrying = None 677 | 678 | def hash(self, size=16): 679 | """Compute a hash that uniquely identifies the current state of the environment. 
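`MiniGridEnv` is abstract: a subclass must implement `_gen_grid` to build `self.grid` and place the agent before `reset` runs its assertions, which is exactly what `HomeGridBase._gen_grid` does above. A minimal sketch of that contract (the `EmptyEnv` name and layout are invented for illustration):

```
class EmptyEnv(MiniGridEnv):
    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)  # surrounding walls
        self.place_agent()  # random free cell and direction

env = EmptyEnv(grid_size=8, max_steps=50)
obs, info = env.reset()
assert obs["direction"] == env.agent_dir
```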
680 | :param size: Size of the hashing 681 | """ 682 | sample_hash = hashlib.sha256() 683 | 684 | to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir] 685 | for item in to_encode: 686 | sample_hash.update(str(item).encode("utf8")) 687 | 688 | return sample_hash.hexdigest()[:size] 689 | 690 | @property 691 | def steps_remaining(self): 692 | return self.max_steps - self.step_count 693 | 694 | @abstractmethod 695 | def _gen_grid(self, width, height): 696 | pass 697 | 698 | def _reward(self): 699 | """ 700 | Compute the reward to be given upon success 701 | """ 702 | 703 | return 1 - 0.9 * (self.step_count / self.max_steps) 704 | 705 | def _rand_int(self, low, high): 706 | """ 707 | Generate random integer in [low,high[ 708 | """ 709 | 710 | return self.np_random.integers(low, high) 711 | 712 | def _rand_float(self, low, high): 713 | """ 714 | Generate random float in [low,high[ 715 | """ 716 | 717 | return self.np_random.uniform(low, high) 718 | 719 | def _rand_bool(self): 720 | """ 721 | Generate random boolean value 722 | """ 723 | 724 | return self.np_random.integers(0, 2) == 0 725 | 726 | def _rand_elem(self, iterable): 727 | """ 728 | Pick a random element in a list 729 | """ 730 | 731 | lst = list(iterable) 732 | idx = self._rand_int(0, len(lst)) 733 | return lst[idx] 734 | 735 | def _rand_subset(self, iterable, num_elems): 736 | """ 737 | Sample a random subset of distinct elements of a list 738 | """ 739 | 740 | lst = list(iterable) 741 | assert num_elems <= len(lst) 742 | 743 | out = [] 744 | 745 | while len(out) < num_elems: 746 | elem = self._rand_elem(lst) 747 | lst.remove(elem) 748 | out.append(elem) 749 | 750 | return out 751 | 752 | def _rand_pos(self, xLow, xHigh, yLow, yHigh): 753 | """ 754 | Generate a random (x,y) position tuple 755 | """ 756 | 757 | return ( 758 | self.np_random.integers(xLow, xHigh), 759 | self.np_random.integers(yLow, yHigh), 760 | ) 761 | 762 | def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf): 763 | """ 764 | Place an object at an empty position in the grid 765 | 766 | :param top: top-left position of the rectangle where to place 767 | :param size: size of the rectangle where to place 768 | :param reject_fn: function to filter out potential positions 769 | """ 770 | 771 | if top is None: 772 | top = (0, 0) 773 | else: 774 | top = (max(top[0], 0), max(top[1], 0)) 775 | 776 | if size is None: 777 | size = (self.grid.width, self.grid.height) 778 | 779 | num_tries = 0 780 | 781 | while True: 782 | # This is to handle with rare cases where rejection sampling 783 | # gets stuck in an infinite loop 784 | if num_tries > max_tries: 785 | raise RecursionError("rejection sampling failed in place_obj") 786 | 787 | num_tries += 1 788 | 789 | pos = np.array( 790 | ( 791 | self._rand_int(top[0], min(top[0] + size[0], self.grid.width)), 792 | self._rand_int(top[1], min(top[1] + size[1], self.grid.height)), 793 | ) 794 | ) 795 | 796 | pos = tuple(pos) 797 | 798 | # Don't place the object on top of another object 799 | if self.grid.get(*pos) is not None and not self.grid.get(*pos).can_overlap(): 800 | continue 801 | 802 | # Don't place the object where the agent is 803 | if np.array_equal(pos, self.agent_pos): 804 | continue 805 | 806 | # Check if there is a filtering criterion 807 | if reject_fn and reject_fn(self, pos): 808 | continue 809 | 810 | break 811 | 812 | self.grid.set(pos[0], pos[1], obj) 813 | 814 | if obj is not None: 815 | obj.init_pos = pos 816 | obj.cur_pos = pos 817 | 818 | return pos 819 | 820 | def 
put_obj(self, obj, i, j): 821 | """ 822 | Put an object at a specific position in the grid 823 | """ 824 | 825 | self.grid.set(i, j, obj) 826 | obj.init_pos = (i, j) 827 | obj.cur_pos = (i, j) 828 | 829 | def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf): 830 | """ 831 | Set the agent's starting point at an empty position in the grid 832 | """ 833 | 834 | self.agent_pos = (-1, -1) 835 | pos = self.place_obj(None, top, size, max_tries=max_tries) 836 | self.agent_pos = pos 837 | 838 | if rand_dir: 839 | self.agent_dir = self._rand_int(0, 4) 840 | 841 | return pos 842 | 843 | @property 844 | def dir_vec(self): 845 | """ 846 | Get the direction vector for the agent, pointing in the direction 847 | of forward movement. 848 | """ 849 | 850 | assert self.agent_dir >= 0 and self.agent_dir < 4 851 | return DIR_TO_VEC[self.agent_dir] 852 | 853 | @property 854 | def right_vec(self): 855 | """ 856 | Get the vector pointing to the right of the agent. 857 | """ 858 | 859 | dx, dy = self.dir_vec 860 | return np.array((-dy, dx)) 861 | 862 | @property 863 | def front_pos(self): 864 | """ 865 | Get the position of the cell that is right in front of the agent 866 | """ 867 | 868 | return self.agent_pos + self.dir_vec 869 | 870 | def get_view_coords(self, i, j): 871 | """ 872 | Translate and rotate absolute grid coordinates (i, j) into the 873 | agent's partially observable view (sub-grid). Note that the resulting 874 | coordinates may be negative or outside of the agent's view size. 875 | """ 876 | 877 | ax, ay = self.agent_pos 878 | dx, dy = self.dir_vec 879 | rx, ry = self.right_vec 880 | 881 | # Compute the absolute coordinates of the top-left view corner 882 | sz = self.agent_view_size 883 | hs = self.agent_view_size // 2 884 | tx = ax + (dx * (sz - 1)) - (rx * hs) 885 | ty = ay + (dy * (sz - 1)) - (ry * hs) 886 | 887 | lx = i - tx 888 | ly = j - ty 889 | 890 | # Project the coordinates of the object relative to the top-left 891 | # corner onto the agent's own coordinate system 892 | vx = rx * lx + ry * ly 893 | vy = -(dx * lx + dy * ly) 894 | 895 | return vx, vy 896 | 897 | def get_view_exts(self, agent_view_size=None): 898 | """ 899 | Get the extents of the square set of tiles visible to the agent 900 | Note: the bottom extent indices are not included in the set 901 | if agent_view_size is None, use self.agent_view_size 902 | """ 903 | 904 | agent_view_size = agent_view_size or self.agent_view_size 905 | 906 | if not CENTERED_VIEW: 907 | # Facing right 908 | if self.agent_dir == 0: 909 | topX = self.agent_pos[0] 910 | topY = self.agent_pos[1] - agent_view_size // 2 911 | # Facing down 912 | elif self.agent_dir == 1: 913 | topX = self.agent_pos[0] - agent_view_size // 2 914 | topY = self.agent_pos[1] 915 | # Facing left 916 | elif self.agent_dir == 2: 917 | topX = self.agent_pos[0] - agent_view_size + 1 918 | topY = self.agent_pos[1] - agent_view_size // 2 919 | # Facing up 920 | elif self.agent_dir == 3: 921 | topX = self.agent_pos[0] - agent_view_size // 2 922 | topY = self.agent_pos[1] - agent_view_size + 1 923 | else: 924 | assert False, "invalid agent direction" 925 | 926 | botX = topX + agent_view_size 927 | botY = topY + agent_view_size 928 | else: 929 | topX = self.agent_pos[0] - 1 930 | topY = self.agent_pos[1] - 1 931 | botX = self.agent_pos[0] + 2 932 | botY = self.agent_pos[1] + 2 933 | return (topX, topY, botX, botY) 934 | 935 | def relative_coords(self, x, y): 936 | """ 937 | Check if a grid position belongs to the agent's field of view, and returns the 
corresponding coordinates 938 | """ 939 | 940 | vx, vy = self.get_view_coords(x, y) 941 | 942 | if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size: 943 | return None 944 | 945 | return vx, vy 946 | 947 | def in_view(self, x, y): 948 | """ 949 | check if a grid position is visible to the agent 950 | """ 951 | 952 | return self.relative_coords(x, y) is not None 953 | 954 | def step(self, action): 955 | self.step_count += 1 956 | 957 | reward = 0 958 | terminated = False 959 | truncated = False 960 | 961 | # Get the position in front of the agent 962 | fwd_pos = self.front_pos 963 | 964 | # Get the contents of the cell in front of the agent 965 | fwd_cell = self.grid.get(*fwd_pos) 966 | 967 | # Rotate left 968 | if action == self.actions.left: 969 | self.agent_dir -= 1 970 | if self.agent_dir < 0: 971 | self.agent_dir += 4 972 | 973 | # Rotate right 974 | elif action == self.actions.right: 975 | self.agent_dir = (self.agent_dir + 1) % 4 976 | 977 | # Move forward 978 | elif action == self.actions.forward: 979 | if fwd_cell is None or fwd_cell.can_overlap(): 980 | self.agent_pos = tuple(fwd_pos) 981 | if fwd_cell is not None and fwd_cell.type == "goal": 982 | terminated = True 983 | reward = self._reward() 984 | if fwd_cell is not None and fwd_cell.type == "lava": 985 | terminated = True 986 | 987 | # Pick up an object 988 | elif action == self.actions.pickup: 989 | if fwd_cell and fwd_cell.can_pickup(): 990 | if self.carrying is None: 991 | self.carrying = fwd_cell 992 | self.carrying.cur_pos = np.array([-1, -1]) 993 | self.grid.set(fwd_pos[0], fwd_pos[1], None) 994 | 995 | # Drop an object 996 | elif action == self.actions.drop: 997 | if not fwd_cell and self.carrying: 998 | self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying) 999 | self.carrying.cur_pos = fwd_pos 1000 | self.carrying = None 1001 | 1002 | # Toggle/activate an object 1003 | elif action == self.actions.toggle: 1004 | if fwd_cell: 1005 | fwd_cell.toggle(self, fwd_pos) 1006 | 1007 | # Done action (not used by default) 1008 | elif action == self.actions.done: 1009 | pass 1010 | 1011 | else: 1012 | raise ValueError(f"Unknown action: {action}") 1013 | 1014 | if self.step_count >= self.max_steps: 1015 | truncated = True 1016 | 1017 | if self.render_mode == "human": 1018 | self.render() 1019 | 1020 | obs = self.gen_obs() 1021 | 1022 | return obs, reward, terminated, truncated, {} 1023 | 1024 | def gen_obs_grid(self, agent_view_size=None): 1025 | """ 1026 | Generate the sub-grid observed by the agent. 1027 | This method also outputs a visibility mask telling us which grid 1028 | cells the agent can actually see. 
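Because `CENTERED_VIEW` is enabled at the top of this file, the visible window is always the 3x3 box around the agent whatever direction it faces; the direction-dependent arithmetic only applies to the legacy forward-facing view. A worked check of `get_view_exts` under that assumption (with `agent_view_size=3`, as `HomeGridBase` uses):

```
env.agent_pos = (5, 4)
topX, topY, botX, botY = env.get_view_exts()
assert (topX, topY, botX, botY) == (4, 3, 7, 6)  # 3x3 box centered on (5, 4)
# gen_obs_grid then takes grid.slice(4, 3, 3, 3); coordinates falling
# outside the grid are filled in as Wall().
```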
1029 | if agent_view_size is None, self.agent_view_size is used 1030 | """ 1031 | 1032 | topX, topY, botX, botY = self.get_view_exts(agent_view_size) 1033 | 1034 | agent_view_size = agent_view_size or self.agent_view_size 1035 | 1036 | grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size) 1037 | 1038 | if not CENTERED_VIEW: 1039 | for i in range(self.agent_dir + 1): 1040 | grid = grid.rotate_left() 1041 | 1042 | # Process occluders and visibility 1043 | # Note that this incurs some performance cost 1044 | if not self.see_through_walls: 1045 | vis_mask = grid.process_vis( 1046 | agent_pos=(agent_view_size // 2, agent_view_size - 1) 1047 | ) 1048 | else: 1049 | vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool) 1050 | 1051 | # Make it so the agent sees what it's carrying 1052 | # We do this by placing the carried object at the agent's position 1053 | # in the agent's partially observable view 1054 | if CENTERED_VIEW: 1055 | agent_pos = grid.width // 2, grid.height // 2 1056 | else: 1057 | agent_pos = grid.width // 2, grid.height - 1 1058 | if self.carrying: 1059 | grid.set(*agent_pos, self.carrying) 1060 | else: 1061 | grid.set(*agent_pos, None) 1062 | 1063 | return grid, vis_mask 1064 | 1065 | def gen_obs(self): 1066 | """ 1067 | Generate the agent's view (partially observable, low-resolution encoding) 1068 | """ 1069 | 1070 | grid, vis_mask = self.gen_obs_grid() 1071 | 1072 | # Encode the partially observable view into a numpy array 1073 | # image = grid.encode(vis_mask) 1074 | image = None 1075 | 1076 | # Observations are dictionaries containing: 1077 | # - an image (partially observable view of the environment) 1078 | # - the agent's direction/orientation (acting as a compass) 1079 | # (task language is added separately by the language wrappers) 1080 | obs = {"image": image, "direction": self.agent_dir} 1081 | 1082 | return obs 1083 | 1084 | def get_pov_render(self, tile_size): 1085 | """ 1086 | Render an agent's POV observation for visualization 1087 | """ 1088 | grid, vis_mask = self.gen_obs_grid() 1089 | 1090 | agent_pos = (self.agent_view_size // 2, self.agent_view_size // 2) if \ 1091 | CENTERED_VIEW else (self.agent_view_size // 2, self.agent_view_size - 1) 1092 | agent_dir = self.agent_dir if CENTERED_VIEW else 3 1093 | 1094 | # Render the whole grid 1095 | img = grid.render( 1096 | tile_size, 1097 | agent_pos=agent_pos, 1098 | agent_dir=agent_dir, 1099 | highlight_mask=vis_mask, 1100 | pov_dir=self.agent_dir, 1101 | ) 1102 | 1103 | return img 1104 | 1105 | def get_full_render(self, highlight, tile_size): 1106 | """ 1107 | Render a full (non-partial) observation for visualization 1108 | """ 1109 | # Compute which cells are visible to the agent 1110 | _, vis_mask = self.gen_obs_grid() 1111 | 1112 | # Mask of which cells to highlight 1113 | highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool) 1114 | 1115 | # Compute the world coordinates of the bottom-left corner 1116 | # of the agent's view area 1117 | if CENTERED_VIEW: 1118 | for abs_i in range(self.agent_pos[0] - 1, self.agent_pos[0] + 2): 1119 | for abs_j in range(self.agent_pos[1] - 1, self.agent_pos[1] + 2): 1120 | highlight_mask[abs_i, abs_j] = True 1121 | else: 1122 | f_vec = self.dir_vec 1123 | r_vec = self.right_vec 1124 | top_left = ( 1125 | self.agent_pos 1126 | + f_vec * (self.agent_view_size - 1) 1127 | - r_vec * (self.agent_view_size // 2) 1128 | ) 1129 | 1130 | # For each cell in the visibility mask 1131 | for vis_j in range(0, self.agent_view_size): 1132 | for vis_i in range(0,
self.agent_view_size): 1133 | # If this cell is not visible, don't highlight it 1134 | if not vis_mask[vis_i, vis_j]: 1135 | continue 1136 | 1137 | # Compute the world coordinates of this cell 1138 | abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i) 1139 | 1140 | if abs_i < 0 or abs_i >= self.width: 1141 | continue 1142 | if abs_j < 0 or abs_j >= self.height: 1143 | continue 1144 | 1145 | # Mark this cell to be highlighted 1146 | highlight_mask[abs_i, abs_j] = True 1147 | 1148 | # Render the whole grid 1149 | img = self.grid.render( 1150 | tile_size, 1151 | self.agent_pos, 1152 | self.agent_dir, 1153 | highlight_mask=highlight_mask if highlight else None, 1154 | ) 1155 | 1156 | return img 1157 | 1158 | def get_frame( 1159 | self, 1160 | highlight: bool = True, 1161 | tile_size: int = TILE_PIXELS, 1162 | agent_pov: bool = False, 1163 | ): 1164 | """Returns an RGB image corresponding to the whole environment or the agent's point of view. 1165 | 1166 | Args: 1167 | 1168 | highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color. 1169 | tile_size (int): How many pixels will form a tile from the NxM grid. 1170 | agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent. 1171 | 1172 | Returns: 1173 | 1174 | frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image. 1175 | 1176 | """ 1177 | if agent_pov: 1178 | return self.get_pov_render(tile_size) 1179 | else: 1180 | return self.get_full_render(highlight, tile_size) 1181 | 1182 | def render(self): 1183 | 1184 | img = self.get_frame(self.highlight, self.tile_size, self.agent_pov) 1185 | 1186 | if self.render_mode == "human": 1187 | if self.window is None: 1188 | self.window = Window("minigrid") 1189 | self.window.show(block=False) 1190 | self.window.show_img(img) 1191 | elif self.render_mode == "rgb_array": 1192 | return img 1193 | 1194 | def close(self): 1195 | if self.window: 1196 | self.window.close() 1197 | --------------------------------------------------------------------------------
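Taken together, the classes above compose into the full environment stack. A minimal end-to-end rollout sketch (class names are from the files above; the registered entry points in `homegrid/__init__.py` and the extra observation wrappers in `wrappers.py` may differ):

```
from homegrid.homegrid_base import HomeGridBase
from homegrid.language_wrappers import MultitaskWrapper, LanguageWrapper

env = LanguageWrapper(
    MultitaskWrapper(HomeGridBase(max_steps=100)),
    lang_types=["task"],  # instruction-only, so no preread phase is needed
)
obs, info = env.reset()
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
print("tasks completed:", env.accomplished_tasks)
```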