├── sso ├── __init__.py ├── env │ ├── nethack │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── base.py │ │ └── maps.py │ ├── __init__.py │ └── scienceworld.py ├── memory │ ├── examples.py │ ├── __init__.py │ └── skillset.py ├── llm │ ├── __init__.py │ └── gpt.py ├── agent │ ├── __init__.py │ ├── fewshot.py │ ├── reflexion.py │ └── skills.py ├── utils.py ├── trajectory.py └── skill.py ├── .gitignore ├── static ├── images │ ├── sso.png │ ├── uci.png │ ├── usage.png │ ├── incontext.png │ ├── scienceworld.png │ ├── sso_example.png │ ├── ai2-logo-header.png │ ├── ai2_website_top.png │ └── aristo-logo-header.png ├── js │ ├── index.js │ ├── bulma-slider.min.js │ ├── bulma-slider.js │ └── bulma-carousel.min.js └── css │ ├── index.css │ ├── bulma-carousel.min.css │ └── bulma-slider.min.css ├── requirements.txt ├── README.md ├── main.py ├── LICENSE └── index.html /sso/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sso/env/nethack/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | results/ 3 | .vscode/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /static/images/sso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/sso.png -------------------------------------------------------------------------------- /static/images/uci.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/uci.png -------------------------------------------------------------------------------- /static/images/usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/usage.png -------------------------------------------------------------------------------- /static/images/incontext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/incontext.png -------------------------------------------------------------------------------- /static/images/scienceworld.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/scienceworld.png -------------------------------------------------------------------------------- /static/images/sso_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/sso_example.png -------------------------------------------------------------------------------- /static/images/ai2-logo-header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/ai2-logo-header.png -------------------------------------------------------------------------------- /static/images/ai2_website_top.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/ai2_website_top.png -------------------------------------------------------------------------------- /static/images/aristo-logo-header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/aristo-logo-header.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.28.0 2 | git+https://github.com/allenai/ScienceWorld.git@exhaustivevalidactions 3 | matplotlib==3.5.3 4 | networkx==2.6.3 5 | minihack==0.1.4 6 | nle-language-wrapper==0.2.0 -------------------------------------------------------------------------------- /sso/memory/examples.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from sso.trajectory import Trajectory 4 | from sso.memory import Memory 5 | 6 | 7 | class ExamplesMemory(Memory): 8 | 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | self.best_reward = float("-inf") 12 | 13 | def insert(self, trajectory: Trajectory) -> None: 14 | reward = sum(x.reward for x in trajectory if x.reward is not None) 15 | if reward >= self.best_reward: 16 | self.trajectories.append(trajectory) 17 | self.best_reward = max(reward, self.best_reward) 18 | best_trajectories = [] 19 | for traj in self.trajectories: 20 | if sum(x.reward for x in traj if x.reward is not None) == self.best_reward: 21 | best_trajectories.append(traj) 22 | self.trajectories = best_trajectories 23 | 24 | def get_memories(self, trajectory: Trajectory = None, n: int = None) -> List[Trajectory]: 25 | if n is None: 26 | n = len(self.trajectories) 27 | return self.trajectories[-n:] 28 | -------------------------------------------------------------------------------- /sso/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from sso.llm.gpt import get_response as get_response_gpt, get_embedding as get_embedding_gpt 4 | 5 | 6 | global DEFAULT_MODEL 7 | DEFAULT_MODEL = None 8 | 9 | global DEFAULT_TEMP 10 | DEFAULT_TEMP = None 11 | 12 | global DEFAULT_EMBEDDING 13 | DEFAULT_EMBEDDING = None 14 | 15 | 16 | def set_default_model(model: str = None, temp: float = None, embedding: str = None) -> None: 17 | if model is not None: 18 | global DEFAULT_MODEL 19 | DEFAULT_MODEL = model 20 | 21 | if temp is not None: 22 | global DEFAULT_TEMP 23 | DEFAULT_TEMP = temp 24 | 25 | if embedding is not None: 26 | global DEFAULT_EMBEDDING 27 | DEFAULT_EMBEDDING = embedding 28 | 29 | 30 | def query_llm(messages: List[Dict[str, str]], model: str = None, temperature: float = 1, **generation_kwargs) -> str: 31 | if model is None: 32 | global DEFAULT_MODEL 33 | model = DEFAULT_MODEL 34 | if temperature is None: 35 | global DEFAULT_TEMP 36 | temperature = DEFAULT_TEMP 37 | if model.startswith("gpt"): 38 | return get_response_gpt(model, messages, temperature=temperature, **generation_kwargs) 39 | 40 | 41 | def get_embedding(content: str, model: str = None) -> List[float]: 42 | if model is None: 43 | global DEFAULT_EMBEDDING 44 | model = DEFAULT_EMBEDDING 45 | 46 | if model.startswith("text"): 47 | return get_embedding_gpt(model, content) 48 | -------------------------------------------------------------------------------- /sso/env/__init__.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Any 2 | from abc import ABC, abstractmethod, abstractproperty 3 | 4 | from sso.trajectory import State 5 | 6 | 7 | class Env(ABC): 8 | 9 | @abstractproperty 10 | def max_steps(self) -> int: 11 | """ 12 | :return: maximum number of steps in the simulation 13 | """ 14 | raise NotImplementedError 15 | 16 | @abstractproperty 17 | def num_train(self) -> int: 18 | """ 19 | :return: number of training variants 20 | """ 21 | raise NotImplementedError 22 | 23 | @abstractproperty 24 | def num_test(self) -> int: 25 | """ 26 | :return: number of test variants 27 | """ 28 | raise NotImplementedError 29 | 30 | @abstractproperty 31 | def train_ids(self) -> Tuple[str]: 32 | """ 33 | :return: training task IDs 34 | """ 35 | raise NotImplementedError 36 | 37 | @abstractproperty 38 | def test_ids(self) -> Tuple[str]: 39 | """ 40 | :return: test task IDs 41 | """ 42 | raise NotImplementedError 43 | 44 | @abstractmethod 45 | def reset(self, task_id: str = None, test: bool = False, **kwargs) -> Tuple[State, Dict[str, Any]]: 46 | """ 47 | Reset the simulation. 48 | 49 | :param test: whether to use a test task 50 | :param kwargs: additional arguments 51 | :return: initial state and info 52 | """ 53 | raise NotImplementedError 54 | 55 | @abstractmethod 56 | def step(self, action: str) -> Tuple[State, float, bool, Dict[str, Any]]: 57 | """ 58 | Perform an action in the simulation. 59 | 60 | :param action: action to perform 61 | :return: next state, reward, done, info 62 | """ 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /sso/llm/gpt.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import os 3 | import time 4 | from functools import lru_cache 5 | import openai 6 | openai.api_key = os.environ["OPENAI_API_KEY"] 7 | 8 | 9 | def get_response(model: str, messages: List[Dict[str, str]], max_tries=50, temperature=1, **kwargs) -> str: 10 | completion = None 11 | num_tries = 0 12 | while not completion and num_tries < max_tries: 13 | try: 14 | completion = openai.ChatCompletion.create( 15 | model=model, 16 | messages=messages, 17 | **kwargs 18 | ).choices[0].message.content 19 | break 20 | except Exception as e: 21 | num_tries += 1 22 | print("try {}: {}".format(num_tries, e)) 23 | if "maximum context length" in str(e): 24 | if len(messages) > 3: 25 | if messages[0]["role"] == "system": 26 | messages = [messages[0]] + messages[3:] 27 | else: 28 | messages = messages[2:] 29 | else: 30 | raise RuntimeError("messages too long") 31 | time.sleep(2) 32 | if not completion: 33 | raise RuntimeError("Failed to get response from API") 34 | return completion 35 | 36 | 37 | @lru_cache(maxsize=1000) 38 | def get_embedding(model: str, content: str, max_tries=50) -> List[float]: 39 | embedding = None 40 | num_tries = 0 41 | while not embedding and num_tries < max_tries: 42 | try: 43 | embedding = openai.Embedding.create(model=model, input=content).data[0].embedding 44 | break 45 | except Exception as e: 46 | num_tries += 1 47 | print("try {}: {}".format(num_tries, e)) 48 | time.sleep(2) 49 | if not embedding: 50 | raise RuntimeError("Failed to get embedding response from API") 51 | return embedding 52 | -------------------------------------------------------------------------------- /sso/agent/__init__.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, List, Dict 2 | 3 | from sso.trajectory import Trajectory 4 | from sso.memory import Memory 5 | 6 | 7 | class Agent: 8 | 9 | def __init__(self, memory: Memory = None, max_history: int = 10, trim_states: bool = True): 10 | assert memory is not None 11 | self.memory = memory 12 | self._log = [] 13 | self.freeze_memory = False 14 | self.max_history = max_history 15 | self.trim_states = trim_states 16 | 17 | def train(self, train: bool = True) -> None: 18 | self.freeze_memory = not train 19 | 20 | def clear(self) -> None: 21 | self.memory.clear() 22 | 23 | def act(self, trajectory: Trajectory) -> str: 24 | self.step_log() 25 | self.log("reward", trajectory[-1].reward) 26 | self.log("last_action", trajectory[-1].last_action) 27 | self.log("state_description", trajectory[-1].state_description) 28 | 29 | # Get next action 30 | return None 31 | 32 | def _update_memory(self, trajectory: Trajectory) -> None: 33 | raise NotImplementedError 34 | 35 | def record_done(self, trajectory: Trajectory) -> None: 36 | self.step_log() 37 | self.log("reward", trajectory[-1].reward) 38 | self.log("last_action", trajectory[-1].last_action) 39 | self.log("state_description", trajectory[-1].state_description) 40 | if not self.freeze_memory: 41 | self._update_memory(trajectory) 42 | 43 | def save(self, save_dir: str) -> None: 44 | self.memory.save(save_dir) 45 | 46 | def load(self, load_dir: str, rebuild: bool = False) -> None: 47 | self.memory.load(load_dir, rebuild=rebuild) 48 | 49 | def log(self, key: str, value: Any) -> None: 50 | if len(self._log) == 0: 51 | self.step_log() 52 | if key in self._log[-1]: 53 | if isinstance(self._log[-1][key], list): 54 | self._log[-1][key].append(value) 55 | else: 56 | self._log[-1][key] = [self._log[-1][key], value] 57 | else: 58 | self._log[-1][key] = value 59 | 60 | def step_log(self) -> None: 61 | self._log.append(dict()) 62 | 63 | def reset_log(self) -> None: 64 | self._log = [] 65 | 66 | def get_log(self) -> List[Dict[str, Any]]: 67 | return self._log 68 | -------------------------------------------------------------------------------- /sso/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, TYPE_CHECKING 3 | if TYPE_CHECKING: 4 | from sso.trajectory import State 5 | 6 | from functools import lru_cache 7 | import numpy as np 8 | 9 | from sso.llm import get_embedding 10 | 11 | 12 | @lru_cache(maxsize=100000) 13 | def clean_feature(feature: str, remove_fill_words=False) -> str: 14 | words_to_skip = set(["to", "the", "for", "on", "in", "a", "an", ""]) if remove_fill_words else set([""]) 15 | words = [] 16 | for word in feature.lower().split(): 17 | word = word.strip(".,!?\"')(}{][:; \t\n") 18 | if word not in words_to_skip: 19 | words.append(word) 20 | return " ".join(words) 21 | 22 | 23 | @lru_cache(maxsize=100000) 24 | def _get_emedding_similarity(text1: str, text2: str) -> float: 25 | if clean_feature(text1, remove_fill_words=True) == clean_feature(text2, remove_fill_words=True): 26 | return 1.0 27 | embedding1 = np.array(get_embedding(text1)) 28 | embedding2 = np.array(get_embedding(text2)) 29 | return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) 30 | 31 | 32 | def get_similarity(text1: str, text2: str) -> float: 33 | return _get_emedding_similarity(text1, text2) 34 | 35 | 36 | def get_feature_similarity(feature: 
str, features: List[str]) -> float: 37 | if isinstance(features, str): 38 | features = [features] 39 | return max(get_similarity(feature, feature2) for feature2 in features) 40 | 41 | 42 | def get_similar_action(action: str, actions: List[str]) -> str: 43 | for x in actions: 44 | if clean_feature(action, remove_fill_words=True) == clean_feature(x, remove_fill_words=True): 45 | return x 46 | return None 47 | 48 | 49 | def get_state_similarity(state1: State, state2: State, init_state: bool = False) -> float: 50 | if state1.last_action is None and state2.last_action is None or init_state: 51 | return _get_emedding_similarity(state1.state_description, state2.state_description) 52 | elif state1.last_action is not None and state2.last_action is not None: 53 | return _get_emedding_similarity( 54 | "You chose to {}.\n\n".format(state1.last_action) + state1.state_description, 55 | "You chose to {}.\n\n".format(state2.last_action) + state2.state_description 56 | ) 57 | else: 58 | return 0 59 | -------------------------------------------------------------------------------- /sso/env/nethack/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, List 2 | from nle_language_wrapper.nle_language_obsv import NLELanguageObsv 3 | 4 | 5 | NLE_LANG = NLELanguageObsv() 6 | 7 | 8 | ACTION_TEMPLATES = [ 9 | "k: move north", 10 | "l: move east", 11 | "j: move south", 12 | "h: move west", 13 | "y: move northwest", 14 | "u: move northeast", 15 | "b: move southwest", 16 | "n: move southeast", 17 | ",: pick up at current location", 18 | "d: drop item", 19 | "o: open door", 20 | "t: throw item", 21 | "e: eat food", 22 | "w: wield weapon", 23 | "W: wear armor", 24 | "T: take off armor", 25 | "P: put on accessory", 26 | "R: remove accessory", 27 | "q: quaff/drink potion", 28 | "a: apply/use item", 29 | "z: zap wand" 30 | ] 31 | 32 | 33 | def get_message(obs) -> str: 34 | message = NLE_LANG.text_message(obs["tty_chars"]).decode("latin-1") 35 | if message is None: 36 | message = "" 37 | return message.replace("\n", "; ") 38 | 39 | 40 | def get_vision(obs) -> str: 41 | vision = NLE_LANG.text_glyphs(obs["glyphs"], obs["blstats"]).decode("latin-1") 42 | for dir1 in ["east", "west", "north", "south"]: 43 | for dir2 in ["northwest", "northeast", "southwest", "southeast"]: 44 | vision = vision.replace(dir1 + dir2, dir2) 45 | return vision 46 | 47 | 48 | def get_lang_obs(obs: Dict, as_list: bool = False, use_stats: bool = False) -> Union[str, List[str]]: 49 | text_fields = { 50 | "text_glyphs": get_vision(obs), 51 | "text_message": get_message(obs), 52 | "text_blstats": NLE_LANG.text_blstats(obs["blstats"]).decode("latin-1"), 53 | "text_inventory": NLE_LANG.text_inventory(obs["inv_strs"], obs["inv_letters"]).decode("latin-1"), 54 | "text_cursor": NLE_LANG.text_cursor(obs["glyphs"], obs["blstats"], obs["tty_cursor"]).decode("latin-1"), 55 | } 56 | 57 | lang_obs = ["You have " + x[3:] for x in text_fields["text_inventory"].split("\n") if x] 58 | if use_stats: 59 | lang_obs += [x for x in text_fields["text_blstats"].split("\n") if x] 60 | lang_obs += [text_fields["text_cursor"]] 61 | lang_obs += ["You see a " + x for x in text_fields["text_glyphs"].split("\n") if x] 62 | if text_fields["text_message"]: 63 | lang_obs += [text_fields["text_message"]] 64 | if as_list: 65 | return lang_obs 66 | else: 67 | return ". 
".join(lang_obs) 68 | -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /sso/agent/fewshot.py: -------------------------------------------------------------------------------- 1 | from sso.trajectory import Trajectory 2 | from sso.llm import query_llm 3 | from sso.agent import Agent 4 | 5 | 6 | class FewshotAgent(Agent): 7 | 8 | def __init__(self, fewshot: int = 3, **kwargs): 9 | super().__init__(**kwargs) 10 | self.fewshot = fewshot 11 | 12 | def act(self, trajectory: Trajectory) -> str: 13 | super().act(trajectory) 14 | 15 | # Get next action 16 | return self._act(trajectory) 17 | 18 | def _update_memory(self, trajectory: Trajectory) -> None: 19 | self.memory.insert(trajectory) 20 | 21 | def _act(self, trajectory: Trajectory) -> str: 22 | sub_trajectory = trajectory.slice(-self.max_history, None) 23 | 24 | system_message = "You are 
playing a text-based game in which you must interact with your surroundings to complete a task.\n\n" 25 | system_message += sub_trajectory.task_description 26 | system_message += "\n\nGiven the state, reflect on what has happened so far, explain your plan to accomplish the task and then output the next action to execute (use one of the action templates below)." 27 | system_message += "\n\nFor example:\nThe last action had the effect of... To accomplish the task, I will need to...\nCurrent subgoal: [subgoal]\nNext action: [action]" 28 | if sub_trajectory[-1].action_prompt is not None: 29 | system_message += "\n\n" + sub_trajectory[-1].action_prompt 30 | 31 | if len(self.memory.trajectories) > 0: 32 | examples = self.memory.get_memories(n=self.fewshot) 33 | system_message += "\n\nUse the following example trajector{} to help you accomplish the task:".format( 34 | "ies" if len(examples) > 1 else "y" 35 | ) 36 | for traj in examples: 37 | system_message += "\n\n" + traj.build_string(trim_state=self.trim_states) 38 | 39 | messages = [dict(role="system", content=system_message)] 40 | 41 | state_strings = sub_trajectory.build_string(use_task=False, use_actions=False, as_list=True, trim_state=self.trim_states) 42 | for i in range(len(sub_trajectory)): 43 | state = sub_trajectory[i] 44 | 45 | if i > 0: 46 | action_start = state.last_generation.lower().find("next action:") 47 | if action_start != -1: 48 | action_start += len("next action:") 49 | action_text = state.last_generation[:action_start+1] + " " + state.last_action 50 | else: 51 | action_text = "To accomplish the task, I will need to {}. Next action: {}".format(state.last_action, state.last_action) 52 | messages.append(dict(role="assistant", content=action_text)) 53 | 54 | prompt = state_strings[i] 55 | messages.append(dict(role="user", content=prompt)) 56 | 57 | response = query_llm(messages) 58 | self.log("action_messages", messages) 59 | self.log("action_generation", response) 60 | return response 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Skill Set Optimization 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |

10 | 11 | ## Installation 12 | 13 | Install the Python library requirements with: 14 | 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Usage 20 | 21 | Executing `main.py` with the environment variable `OPENAI_API_KEY` set will run SSO or baselines on ScienceWorld or NetHack tasks. 22 | 23 | ``` 24 | OPENAI_API_KEY=[your key] python main.py \ 25 | --agent skills \ 26 | --memory skills \ 27 | --env scienceworld \ 28 | --task measure-melting-point-known-substance \ 29 | --train_variant_count 10 \ 30 | --test_variants 399 354 390 418 335 409 377 410 385 367 31 | ``` 32 | 33 | By default, the script will train the agent for 30 iterations and then test its ability to transfer to the specified test variants. 34 | Use `--test_init` and `--test_adapt` to also test the base agent's performance and the agent's ability to adapt to each of the test tasks in at most 5 iterations. 35 | 36 | ### Agent 37 | 38 | Execute the SSO agent using the command line arguments below: 39 | 40 | ``` 41 | --agent skills 42 | --memory skills 43 | ``` 44 | 45 | Execute a Reflexion baseline using: 46 | 47 | ``` 48 | --agent reflexion 49 | ``` 50 | 51 | Execute a Fewshot baseline using: 52 | 53 | ``` 54 | --agent fewshot 55 | --memory examples 56 | ``` 57 | 58 | ### Environment 59 | 60 | #### NetHack 61 | 62 | Evaluate on our custom NetHack task using: 63 | 64 | ``` 65 | --env nethack 66 | --task MiniHack-KeyLavaCross-v0 67 | --max_history 5 68 | --full_states 69 | --test_iters 10 70 | --train_iters 30 71 | ``` 72 | 73 | Other NetHack environments can be used by altering the `task` argument. 74 | However, SSO has only been tested with `MiniHack-KeyLavaCross-v0`. 75 | We recommend setting the `max_history` and `full_states` arguments for NetHack as shown above. 76 | 77 | #### ScienceWorld 78 | 79 | Evaluate on a ScienceWorld task and variants using: 80 | 81 | ``` 82 | --env scienceworld 83 | --task measure-melting-point-known-substance 84 | --train_variant_count 10 85 | --test_variants 399 354 390 418 335 409 377 410 385 367 86 | --test_iters 1 87 | --train_iters 3 88 | ``` 89 | 90 | Note that `task` can be set to any valid ScienceWorld task, and variants can be selected randomly with `train_variant_count`/`test_variant_count` or specified explicitly with `train_variants`/`test_variants`.
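
### Programmatic use

The components can also be wired together directly in Python. The snippet below is a minimal, untested sketch assembled from the classes in this repository (`NetHackTask`, `FewshotAgent`, `ExamplesMemory`, and `Trajectory`) and assumes `OPENAI_API_KEY` is set in the environment; the model name is a placeholder, and `main.py` remains the supported entry point with full configuration, logging, and evaluation.

```
from sso.llm import set_default_model
from sso.memory.examples import ExamplesMemory
from sso.agent.fewshot import FewshotAgent
from sso.env.nethack.base import NetHackTask
from sso.trajectory import Trajectory

set_default_model(model="gpt-4")  # placeholder; query_llm routes any "gpt*" chat model to the OpenAI API

env = NetHackTask(task="MiniHack-KeyLavaCross-v0", max_steps=100)
agent = FewshotAgent(memory=ExamplesMemory(), max_history=5)
agent.train(True)  # leave memory unfrozen so the finished episode is stored

state, info = env.reset()
trajectory = Trajectory(task_description=info["task_description"])
trajectory.insert(state)

done = False
while not done:
    response = agent.act(trajectory)                # raw LLM generation containing "Next action: ..."
    state, reward, done, info = env.step(response)  # the env parses the action out of the generation
    trajectory.insert(state)

agent.record_done(trajectory)  # folds the finished trajectory into the agent's memory
print("success:", info["success"], "return:", sum(s.reward for s in trajectory))
```

The `SkillsAgent` and `ReflexionAgent` expose the same `act`/`record_done` interface, so the same loop applies to them as well.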
91 | 92 | ## Citation 93 | ```bib 94 | @article{nottingham2024sso, 95 | author = "Nottingham, Kolby and Majumder, Bodhisattwa Prasad and Dalvi Mishra, Bhavana and Singh, Sameer and Clark, Peter and Fox, Roy", 96 | title = "Skill Set Optimization: Reinforcing Language Model Behavior via Transferable Skills", 97 | journal = "arXiv", 98 | year = "2024", 99 | url = "https://arxiv.org/abs/2402.03244" 100 | } 101 | ``` -------------------------------------------------------------------------------- /sso/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | 7 | from sso.trajectory import Trajectory 8 | from sso.skill import Skill 9 | 10 | 11 | class Memory: 12 | 13 | def __init__( 14 | self, 15 | max_trajectories: int = 10, 16 | skill_memory_scale: float = 2.0, 17 | max_retrieval: int = 3, 18 | discount: float = 0.9, 19 | ): 20 | self.max_trajectories = max_trajectories 21 | self.skill_memory_scale = skill_memory_scale 22 | self.max_retrieval = max_retrieval 23 | self.discount = discount 24 | 25 | self.trajectories: List[Trajectory] = [] 26 | self.memory: List[Skill] = [] 27 | self.sampled: List[Skill] = [] 28 | self.skill_feedback: List[List[Tuple[Skill, float, bool]]] = [[]] 29 | 30 | def insert(self, trajectory: Trajectory) -> Trajectory: 31 | raise NotImplementedError 32 | 33 | def get_memories(self, trajectory: Trajectory = None, n: int = None) -> List[Union[Trajectory, Skill]]: 34 | raise NotImplementedError 35 | 36 | def build(self, trajectories: Union[Trajectory, List[Trajectory]] = [], **kwargs) -> None: 37 | if isinstance(trajectories, Trajectory): 38 | trajectories = [trajectories] 39 | print("Inserting trajectories...") 40 | for trajectory in tqdm(trajectories): 41 | self.insert(trajectory) 42 | 43 | def clear(self) -> None: 44 | self.trajectories: List[Trajectory] = [] 45 | self.memory: List[Skill] = [] 46 | self.sampled: List[Skill] = [] 47 | self.skill_feedback: List[List[Tuple[Skill, float, bool]]] = [[]] 48 | 49 | def save(self, save_path: str) -> None: 50 | os.makedirs(save_path, exist_ok=True) 51 | with open(os.path.join(save_path, "skill_info.json"), "w") as f: 52 | json.dump([ 53 | s.info() for s in self.memory 54 | ], f, indent=4) 55 | with open(os.path.join(save_path, "memory.json"), "w") as f: 56 | json.dump(dict( 57 | trajectories=[s.to_dict() for s in self.trajectories], 58 | skills=[s.to_dict() for s in self.memory], 59 | skill_feedback=[[[r, x, s.to_dict()] for s, r, x in episode] for episode in self.skill_feedback], 60 | ), f, indent=4) 61 | 62 | def load(self, load_path: str, rebuild: bool = False) -> None: 63 | with open(os.path.join(load_path, "memory.json"), "r") as f: 64 | data = json.load(f) 65 | self.trajectories = [Trajectory.from_dict(x) for x in data["trajectories"]] 66 | if rebuild: 67 | self.build() 68 | else: 69 | self.memory = [] 70 | for skill_data in data["skills"]: 71 | skill = Skill.from_dict(skill_data) 72 | skill.build() 73 | self.memory.append(skill) 74 | self.skill_feedback = [[]] 75 | for episode in data["skill_feedback"]: 76 | self.skill_feedback.append([]) 77 | for r, x, skill_data in episode: 78 | skill = Skill.from_dict(skill_data) 79 | skill.build() 80 | self.skill_feedback[-1].append((skill, r, x)) 81 | -------------------------------------------------------------------------------- /static/css/index.css: -------------------------------------------------------------------------------- 1 | 
body { 2 | font-family: 'Noto Sans', sans-serif; 3 | font-size: 15px; 4 | color:rgb(0, 0, 0); 5 | } 6 | 7 | /* .column { 8 | float: left; 9 | width: 50%; 10 | padding: 5px; 11 | } 12 | 13 | /* Clear floats after image containers */ 14 | /* .row::after { 15 | content: ""; 16 | clear: both; 17 | display: table; 18 | } */ 19 | 20 | .skill-table td + td { 21 | border-left:2px solid black; 22 | } 23 | 24 | img { 25 | display: block; 26 | margin-left: auto; 27 | margin-right: auto } 28 | 29 | .footer .icon-link { 30 | font-size: 25px; 31 | color: #000; 32 | } 33 | 34 | .link-block a { 35 | margin-top: 5px; 36 | margin-bottom: 5px; 37 | } 38 | 39 | .dnerf { 40 | font-variant: small-caps; 41 | } 42 | 43 | 44 | .teaser .hero-body { 45 | padding-top: 0; 46 | padding-bottom: 3rem; 47 | } 48 | 49 | .teaser { 50 | font-family: 'Google Sans', sans-serif; 51 | } 52 | 53 | 54 | .publication-title { 55 | } 56 | 57 | .publication-banner { 58 | max-height: parent; 59 | 60 | } 61 | 62 | .publication-banner video { 63 | position: relative; 64 | left: auto; 65 | top: auto; 66 | transform: none; 67 | object-fit: fit; 68 | } 69 | 70 | .publication-header .hero-body { 71 | } 72 | 73 | .publication-title { 74 | font-family: 'Google Sans', sans-serif; 75 | } 76 | 77 | .publication-authors { 78 | font-family: 'Google Sans', sans-serif; 79 | } 80 | 81 | .publication-venue { 82 | color: #555; 83 | width: fit-content; 84 | font-weight: bold; 85 | } 86 | 87 | .publication-awards { 88 | color: #ff3860; 89 | width: fit-content; 90 | font-weight: bolder; 91 | } 92 | 93 | .publication-authors { 94 | } 95 | 96 | .publication-authors a { 97 | color: hsl(229, 86%, 53%) !important; 98 | } 99 | 100 | .publication-authors a:hover { 101 | text-decoration: underline; 102 | } 103 | 104 | .author-block { 105 | display: inline-block; 106 | } 107 | 108 | .publication-banner img { 109 | } 110 | 111 | .publication-authors { 112 | /*color: #4286f4;*/ 113 | } 114 | 115 | .publication-video { 116 | position: relative; 117 | width: 100%; 118 | height: 0; 119 | padding-bottom: 56.25%; 120 | 121 | overflow: hidden; 122 | border-radius: 10px !important; 123 | } 124 | 125 | .publication-video iframe { 126 | position: absolute; 127 | top: 0; 128 | left: 0; 129 | width: 100%; 130 | height: 100%; 131 | } 132 | 133 | .publication-body img { 134 | } 135 | 136 | .results-carousel { 137 | overflow: hidden; 138 | } 139 | 140 | .results-carousel .item { 141 | margin: 5px; 142 | overflow: hidden; 143 | border: 1px solid #bbb; 144 | border-radius: 10px; 145 | padding: 0; 146 | font-size: 0; 147 | } 148 | 149 | .results-carousel video { 150 | margin: 0; 151 | } 152 | 153 | 154 | .interpolation-panel { 155 | background: #f5f5f5; 156 | border-radius: 10px; 157 | } 158 | 159 | .interpolation-panel .interpolation-image { 160 | width: 100%; 161 | border-radius: 5px; 162 | } 163 | 164 | .interpolation-video-column { 165 | } 166 | 167 | .interpolation-panel .slider { 168 | margin: 0 !important; 169 | } 170 | 171 | .interpolation-panel .slider { 172 | margin: 0 !important; 173 | } 174 | 175 | #interpolation-image-wrapper { 176 | width: 100%; 177 | } 178 | #interpolation-image-wrapper img { 179 | border-radius: 5px; 180 | } 181 | -------------------------------------------------------------------------------- /static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes 
spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center 
center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /sso/agent/reflexion.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import json 4 | 5 | from sso.trajectory import Trajectory 6 | from sso.llm import query_llm 7 | from sso.agent import Agent 8 | from sso.memory import Memory 9 | 10 | 11 | class ReflexionAgent(Agent): 12 | 13 | def __init__(self, max_reflections: int = 3, **kwargs): 14 | super().__init__(**kwargs) 15 | self.memory = Memory() # Dummy memory 16 | self.max_reflections = max_reflections 17 | self.reflections = [] 18 | 19 | def act(self, trajectory: Trajectory) -> str: 20 | super().act(trajectory) 21 | return self._act(trajectory) 22 | 23 | def save(self, save_dir: str, **kwargs) -> None: 24 | with open(os.path.join(save_dir, "reflections.json"), "w") as f: 25 | json.dump(self.reflections, f, indent=4) 26 | 27 | def load(self, load_dir: str, **kwargs) -> None: 28 | with open(os.path.join(load_dir, "reflections.json"), "r") as f: 29 | self.reflections = json.load(f) 30 | 31 | def _update_memory(self, trajectory: Trajectory) -> None: 32 | if sum(x.reward for x in trajectory) < 1: 33 | messages = self._build_messages(trajectory) 34 | messages[-1]["content"] += "\n\nThe task was not completed successfully. What should you do better next time? Be very concise. Respond with a single sentence." 35 | response = query_llm(messages) 36 | self.log("reflection", response) 37 | self.reflections.append(response) 38 | 39 | def _build_messages(self, trajectory: Trajectory) -> List[dict]: 40 | system_message = "You are playing a text-based game in which you must interact with your surroundings to complete a task.\n\n" 41 | system_message += trajectory.task_description 42 | system_message += "\n\nGiven the state, reflect on what has happened so far, explain your plan to accomplish the task and then output the next action to execute (use one of the action templates below)." 43 | system_message += "\n\nFor example:\nThe last action had the effect of... To accomplish the task, I will need to...\nCurrent subgoal: [subgoal]\nNext action: [action]" 44 | if trajectory[-1].action_prompt is not None: 45 | system_message += "\n\n" + trajectory[-1].action_prompt 46 | 47 | if len(self.reflections) > 0: 48 | system_message += "\n\nConsider the following tips:\n" 49 | system_message += "\n".join(self.reflections[-self.max_reflections:]) 50 | 51 | messages = [dict(role="system", content=system_message)] 52 | 53 | state_strings = trajectory.build_string(use_task=False, use_actions=False, as_list=True, trim_state=self.trim_states) 54 | for i in range(len(trajectory)): 55 | state = trajectory[i] 56 | 57 | if i > 0: 58 | action_start = state.last_generation.lower().find("next action:") 59 | if action_start != -1: 60 | action_start += len("next action:") 61 | action_text = state.last_generation[:action_start+1] + " " + state.last_action 62 | else: 63 | action_text = "To accomplish the task, I will need to {}. 
Next action: {}".format(state.last_action, state.last_action) 64 | messages.append(dict(role="assistant", content=action_text)) 65 | 66 | prompt = state_strings[i] 67 | messages.append(dict(role="user", content=prompt)) 68 | 69 | return messages 70 | 71 | def _act(self, trajectory: Trajectory) -> str: 72 | sub_trajectory = trajectory.slice(-self.max_history, None) 73 | messages = self._build_messages(sub_trajectory) 74 | response = query_llm(messages) 75 | self.log("action_messages", messages) 76 | self.log("action_generation", response) 77 | return response 78 | -------------------------------------------------------------------------------- /sso/agent/skills.py: -------------------------------------------------------------------------------- 1 | from sso.trajectory import Trajectory 2 | from sso.llm import query_llm 3 | from sso.agent import Agent 4 | from sso.utils import clean_feature 5 | 6 | 7 | class SkillsAgent(Agent): 8 | 9 | def act(self, trajectory: Trajectory) -> str: 10 | super().act(trajectory) 11 | return self._act(trajectory) 12 | 13 | def _update_memory(self, trajectory: Trajectory) -> None: 14 | self.memory.build(trajectory) 15 | 16 | def _act(self, trajectory: Trajectory) -> str: 17 | sub_trajectory = trajectory.slice(-self.max_history, None) 18 | 19 | system_message = "You are playing a text-based game in which you must interact with your surroundings to complete a task. You will occasionally be given posisible subgoals. You may choose to target one of these subgoals or ignore them.\n\n" 20 | system_message += sub_trajectory.task_description 21 | system_message += "\n\nGiven the state, reflect on what has happened so far, explain your plan to accomplish the task, output which of the given subgoals you are targeting next (match one of the subgoals in the prompt word for word or output \"none\"), and then output the next action to execute (use one of the action templates below)." 22 | system_message += "\n\nFor example:\nThe last action had the effect of... To accomplish the task, I will need to...\nCurrent subgoal: [subgoal]\nNext action: [action]" 23 | if sub_trajectory[-1].action_prompt is not None: 24 | system_message += "\n\n" + sub_trajectory[-1].action_prompt 25 | 26 | skills = self.memory.get_memories(trajectory=sub_trajectory) 27 | if len(skills) > 0: 28 | skill_text = "The following instructions contain potentially useful information about reaching subgoals:" 29 | for skill in skills: 30 | skill_text += "\n\nInstructions for reaching the subgoal \"{}\":\n\t".format(skill.target) 31 | skill_text += "\n\t".join("{}. 
{}".format(i + 1, instruction) for i, instruction in enumerate(skill.instructions)) 32 | system_message += "\n\n" + skill_text 33 | 34 | messages = [dict(role="system", content=system_message)] 35 | 36 | state_strings = sub_trajectory.build_string(use_task=False, use_actions=False, as_list=True, trim_state=self.trim_states) 37 | for i in range(len(sub_trajectory)): 38 | state = sub_trajectory[i] 39 | 40 | if i > 0: 41 | action_start = state.last_generation.lower().find("next action:") 42 | if action_start != -1: 43 | action_start += len("next action:") 44 | action_text = state.last_generation[:action_start+1] + " " + state.last_action 45 | else: 46 | action_text = "To accomplish the task, I will need to {}.\nCurrent subgoal: none\nNext action: {}".format(state.last_action, state.last_action) 47 | messages.append(dict(role="assistant", content=action_text)) 48 | 49 | prompt = state_strings[i] 50 | messages.append(dict(role="user", content=prompt)) 51 | 52 | response = query_llm(messages) 53 | 54 | if "current subgoal:" in response.lower(): 55 | current_skill = clean_feature(response.lower().split("current subgoal:")[-1].split("next action:")[0].strip()) 56 | if current_skill != "none": 57 | for skill in reversed(skills): 58 | if clean_feature(skill.target) == current_skill: 59 | self.memory.log_sampled(len(trajectory) - 1, skill) 60 | self.log("chosen_skill", skill.target) 61 | trajectory[-1].skill_target = skill.target 62 | break 63 | 64 | self.log("action_messages", messages) 65 | self.log("action_generation", response) 66 | return response 67 | -------------------------------------------------------------------------------- /sso/env/nethack/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Any, List 2 | import gym 3 | import minihack 4 | from nle.nethack.actions import * 5 | from nle_language_wrapper import NLELanguageWrapper 6 | from sso.env.nethack.utils import ACTION_TEMPLATES 7 | 8 | import sso.env.nethack.maps 9 | from sso.env.nethack.utils import get_message, get_lang_obs 10 | from sso.env import Env 11 | from sso.trajectory import State 12 | 13 | 14 | class NetHackTask(Env): 15 | 16 | def __init__( 17 | self, 18 | task: str = None, 19 | max_steps: int = 100, 20 | **kwargs 21 | ): 22 | assert task is not None, "You must specify a task to load" 23 | self.task = task 24 | self.env = gym.make( 25 | self.task, 26 | observation_keys=("glyphs", "blstats", "tty_chars", "inv_strs", "inv_letters", "tty_cursor") 27 | ) 28 | self.lang_to_action = NLELanguageWrapper(self.env).pre_step 29 | 30 | self._max_steps = max_steps 31 | self.t = 0 32 | self.last_nle_obs = None 33 | 34 | @property 35 | def action_templates(self) -> List[str]: 36 | return ACTION_TEMPLATES 37 | 38 | def parse_action(self, action: str) -> Tuple[str, List[str]]: 39 | env_action = action.split("Next Action:")[-1].split("Next action:")[-1].split("next action:")[-1] 40 | env_action = env_action.strip(" \n\"'([{") 41 | env_action = env_action[0] if len(env_action) > 0 else env_action 42 | return env_action, [env_action] 43 | 44 | def get_action_prompt(self) -> str: 45 | res = "Below is a mapping of keys to common actions. 
Your generated action should always be a single character.\n\t" 46 | res += "\n\t".join(self.action_templates) 47 | return res 48 | 49 | def reset(self, **kwargs) -> Tuple[State, Dict[str, Any]]: 50 | obs = self.env.reset() 51 | self.t = 0 52 | self.last_nle_obs = obs 53 | info = dict( 54 | task_id=self.task, 55 | task_description="Task Description: " + self.env.TASK_DESCRIPTION, 56 | success=False 57 | ) 58 | return self.get_state(obs, "", "", 0, False), info 59 | 60 | def step(self, action: str) -> Tuple[State, float, bool, Dict[str, Any]]: 61 | parsed_action, env_actions = self.parse_action(action) 62 | 63 | reward, done, info = 0, False, dict() 64 | try: 65 | for env_action in env_actions: 66 | self.last_nle_obs, r, done, info = self.env.step(self.lang_to_action(env_action)) 67 | reward += r 68 | if done: 69 | break 70 | state = self.get_state(self.last_nle_obs, action, parsed_action, reward, done) 71 | except ValueError: 72 | state = State( 73 | observation="Invalid action.", 74 | state_features=["Invalid action."], 75 | state_description="Invalid action.", 76 | last_generation=action, 77 | last_action=parsed_action, 78 | reward=reward, 79 | done=done, 80 | task_description="Task Description: " + self.env.TASK_DESCRIPTION, 81 | action_prompt=self.get_action_prompt(), 82 | all_templates=self.action_templates, 83 | ) 84 | 85 | self.t += 1 86 | if self.t >= self.max_steps: 87 | done = True 88 | info.update(dict( 89 | task_id=self.task, 90 | task_description="Task Description: " + self.env.TASK_DESCRIPTION, 91 | success=info["end_status"] == 2 if "end_status" in info else False, 92 | )) 93 | 94 | return state, reward, done, info 95 | 96 | def get_state(self, obs, generation: str, action: str, reward: float, done: bool) -> State: 97 | return State( 98 | observation=get_message(obs), 99 | state_features=get_lang_obs(obs, as_list=True), 100 | state_description=get_lang_obs(obs, as_list=False), 101 | last_generation=generation, 102 | last_action=action, 103 | reward=reward, 104 | done=done, 105 | task_description="Task Description: " + self.env.TASK_DESCRIPTION, 106 | action_prompt=self.get_action_prompt(), 107 | all_templates=self.action_templates, 108 | ) 109 | 110 | @property 111 | def max_steps(self) -> int: 112 | return self._max_steps 113 | 114 | @property 115 | def num_train(self) -> int: 116 | return 1 117 | 118 | @property 119 | def num_test(self) -> int: 120 | return 1 121 | 122 | @property 123 | def train_ids(self) -> Tuple[str]: 124 | return (self.task,) 125 | 126 | @property 127 | def test_ids(self) -> Tuple[str]: 128 | return (self.task,) 129 | -------------------------------------------------------------------------------- /sso/env/nethack/maps.py: -------------------------------------------------------------------------------- 1 | from minihack import MiniHackSkill 2 | from minihack.envs import register 3 | 4 | from sso.env.nethack.utils import get_message 5 | 6 | 7 | class MiniHackLavacross(MiniHackSkill): 8 | TASK_DESCRIPTION = "find a safe way to cross the lava and navigate to the stairs down. There is an item that can help you behind a locked door." 
9 | 10 | def __init__(self, *args, **kwargs): 11 | des_file = """ 12 | MAZE: "mylevel", ' ' 13 | FLAGS:hardfloor 14 | INIT_MAP: solidfill,' ' 15 | GEOMETRY:center,center 16 | MAP 17 | ------------- 18 | |.|...L.....| 19 | |+-...L.....| 20 | |.....L.....| 21 | |.....L.....| 22 | |.....L.....| 23 | ------------- 24 | ENDMAP 25 | REGION:(0,0,12,6),lit,"ordinary" 26 | $left_bank = selection:fillrect (1,3,5,5) 27 | $right_bank = selection:fillrect (7,1,11,5) 28 | IF [50%] { 29 | OBJECT:('=',"levitation"),(1,1),blessed 30 | } ELSE { 31 | OBJECT:('!',"levitation"),(1,1),blessed 32 | } 33 | OBJECT:('(',"skeleton key"),rndcoord($left_bank),blessed,0,name:"The Master Key of Thievery" 34 | DOOR:locked,(1,2) 35 | STAIR:rndcoord($right_bank),down 36 | BRANCH:(1,1,5,5),(1,1,2,2) 37 | """ 38 | self.picked_up_key = False 39 | self.unlocked_door = False 40 | self.used_item = False 41 | super().__init__(*args, des_file=des_file, max_episode_steps=50, character="rog-hum-cha-mal", **kwargs) 42 | 43 | def reset(self, *args, **kwargs): 44 | self.picked_up_key = False 45 | self.unlocked_door = False 46 | self.used_item = False 47 | return super().reset(*args, **kwargs) 48 | 49 | def step(self, action): 50 | obs, reward, done, info = super().step(action) 51 | reward = 0 52 | message = get_message(obs) 53 | if message: 54 | if not self.picked_up_key and "- a key named The Master Key of Thievery." in message: 55 | reward = .1 56 | self.picked_up_key = True 57 | if not self.unlocked_door and "You succeed in unlocking the door." in message: 58 | reward = .2 59 | self.unlocked_door = True 60 | if not self.used_item and any(x in message for x in [ 61 | "a ring of levitation (on right hand)", 62 | "You start to float in the air!", 63 | ]): 64 | reward = .3 65 | self.used_item = True 66 | if "end_status" in info and info["end_status"] == 2: 67 | reward = .4 68 | return obs, reward, done, info 69 | 70 | 71 | class MiniHackLavacross2(MiniHackSkill): 72 | TASK_DESCRIPTION = "find and navigate to the stairs down. During your search you will need to unlock a door using a key and safely cross lava using a magic ring or boots." 
73 | 74 | def __init__(self, *args, **kwargs): 75 | des_file = """ 76 | MAZE: "mylevel", ' ' 77 | FLAGS:hardfloor 78 | INIT_MAP: solidfill,' ' 79 | GEOMETRY:center,center 80 | MAP 81 | --------- 82 | |...L...| 83 | |...L...| 84 | |...L...| 85 | |...LLLL| 86 | |...|...| 87 | |...+...| 88 | |...|...| 89 | --------- 90 | ENDMAP 91 | REGION:(0,0,8,8),lit,"ordinary" 92 | $left_room = selection:fillrect (1,1,3,7) 93 | $bottom_room = selection:fillrect (5,5,7,7) 94 | $top_room = selection:fillrect (5,1,7,3) 95 | IF [50%] { 96 | OBJECT:('=',"levitation"),rndcoord($bottom_room),blessed 97 | } ELSE { 98 | OBJECT:('!',"levitation"),rndcoord($bottom_room),blessed 99 | } 100 | OBJECT:('(',"skeleton key"),rndcoord($left_room),blessed,0,name:"The Master Key of Thievery" 101 | DOOR:locked,(4,6) 102 | STAIR:rndcoord($top_room),down 103 | BRANCH:(1,1,3,7),(0,0,0,0) 104 | """ 105 | self.picked_up_key = False 106 | self.unlocked_door = False 107 | self.used_item = False 108 | super().__init__(*args, des_file=des_file, max_episode_steps=50, character="rog-hum-cha-mal", **kwargs) 109 | 110 | def reset(self, *args, **kwargs): 111 | self.picked_up_key = False 112 | self.unlocked_door = False 113 | self.used_item = False 114 | return super().reset(*args, **kwargs) 115 | 116 | def step(self, action): 117 | obs, reward, done, info = super().step(action) 118 | reward = 0 119 | message = get_message(obs) 120 | if message: 121 | if not self.picked_up_key and "- a key named The Master Key of Thievery." in message: 122 | reward = .1 123 | self.picked_up_key = True 124 | if not self.unlocked_door and "You succeed in unlocking the door." in message: 125 | reward = .2 126 | self.unlocked_door = True 127 | if not self.used_item and any(x in message for x in [ 128 | "a ring of levitation (on right hand)", 129 | "You start to float in the air!", 130 | ]): 131 | reward = .3 132 | self.used_item = True 133 | if "end_status" in info and info["end_status"] == 2: 134 | reward = .4 135 | return obs, reward, done, info 136 | 137 | 138 | register( 139 | id="MiniHack-KeyLavaCross-v0", 140 | entry_point="sso.env.nethack.maps:MiniHackLavacross", 141 | ) 142 | 143 | register( 144 | id="MiniHack-KeyLavaCross2-v0", 145 | entry_point="sso.env.nethack.maps:MiniHackLavacross2", 146 | ) 147 | -------------------------------------------------------------------------------- /static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); 
-------------------------------------------------------------------------------- /sso/trajectory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Union, Dict, Any 3 | from copy import deepcopy 4 | import numpy as np 5 | 6 | from sso.utils import get_feature_similarity 7 | 8 | 9 | class State: 10 | def __init__( 11 | self, 12 | observation: str = None, 13 | state_description: str = None, 14 | state_features: List[str] = None, 15 | last_generation: str = None, 16 | last_action: str = None, 17 | reward: float = 0, 18 | done: bool = False, 19 | action_prompt: str = None, 20 | all_templates: List[str] = None, 21 | task_description: str = None, 22 | skill_target: str = None 23 | ): 24 | self.observation = observation 25 | self.state_description = state_description 26 | self.state_features = state_features 27 | self.last_generation = last_generation 28 | self.last_action = last_action 29 | self.reward = reward 30 | self.done = done 31 | self.action_prompt = action_prompt 32 | self.all_templates = all_templates 33 | self.task_description = task_description 34 | self.skill_target = skill_target 35 | 36 | def __eq__(self, other: object) -> bool: 37 | return isinstance(other, State) and \ 38 | self.state_features == other.state_features and \ 39 | self.last_action == other.last_action and \ 40 | self.task_description == other.task_description 41 | 42 | def __hash__(self) -> int: 43 | return hash((tuple(self.state_features), self.last_action, self.task_description)) 44 | 45 | def to_dict(self) -> Dict[str, Any]: 46 | return dict( 47 | observation=self.observation, 48 | state_description=self.state_description, 49 | state_features=self.state_features, 50 | last_generation=self.last_generation, 51 | last_action=self.last_action, 52 | reward=self.reward, 53 | done=self.done, 54 | action_prompt=self.action_prompt, 55 | all_templates=self.all_templates, 56 | task_description=self.task_description, 57 | skill_target=self.skill_target 58 | ) 59 | 60 | @staticmethod 61 | def from_dict(data: Dict[str, Any]) -> State: 62 | return State( 63 | observation=data["observation"], 64 | state_description=data["state_description"], 65 | state_features=data["state_features"], 66 | last_generation=data["last_generation"], 67 | last_action=data["last_action"], 68 | reward=data["reward"], 69 | done=data["done"], 70 | action_prompt=data["action_prompt"], 71 | all_templates=data["all_templates"], 72 | task_description=data["task_description"], 73 | skill_target=data["skill_target"] 74 | ) 75 | 76 | def state_similarity(self, state: State) -> float: 77 | return get_feature_similarity(self.state_description, state.state_description) 78 | 79 | def action_similarity(self, state: State) -> float: 80 | if self.last_action is not None and state.last_action is not None: 81 | return get_feature_similarity(self.last_action, state.last_action) 82 | elif self.last_generation is None and state.last_generation is None: 83 | return 1 84 | else: 85 | return 0 86 | 87 | 88 | class Trajectory: 89 | def __init__(self, steps: List[State] = None, task_description: str = None): 90 | self.steps = steps if steps is not None else [] 91 | self.task_description = task_description 92 | 93 | def __len__(self) -> int: 94 | return len(self.steps) 95 | 96 | def __getitem__(self, index: int) -> State: 97 | return self.steps[index] 98 | 99 | def __eq__(self, other: object) -> bool: 100 | return isinstance(other, Trajectory) and self.steps == other.steps and 
\ 101 | self.task_description == other.task_description 102 | 103 | def __hash__(self): 104 | return hash(tuple(self.steps + [self.task_description])) 105 | 106 | def to_dict(self) -> Dict[str, Any]: 107 | return dict( 108 | task_description=self.task_description, 109 | steps=[x.to_dict() for x in self.steps] 110 | ) 111 | 112 | @staticmethod 113 | def from_dict(data: Dict[str, Any]) -> Trajectory: 114 | return Trajectory( 115 | steps=[State.from_dict(x) for x in data["steps"]], 116 | task_description=data["task_description"] 117 | ) 118 | 119 | def slice(self, start: int, end: int) -> Trajectory: 120 | sliced = Trajectory( 121 | deepcopy(self.steps[start:end]), 122 | task_description=self.task_description 123 | ) 124 | return sliced 125 | 126 | def insert(self, state: State): 127 | self.steps.append(state) 128 | 129 | def build_string( 130 | self, 131 | use_task: bool = True, 132 | use_features: bool = True, 133 | use_actions: bool = True, 134 | start_idx: int = 0, 135 | end_idx: int = None, 136 | trim_state: bool = True, 137 | as_list: bool = False 138 | ) -> Union[str, List[str]]: 139 | 140 | return Trajectory.build_string_from_states( 141 | self.steps, 142 | self.task_description if use_task else None, 143 | use_features, 144 | use_actions, 145 | start_idx, 146 | end_idx, 147 | trim_state, 148 | as_list 149 | ) 150 | 151 | @staticmethod 152 | def build_string_from_states( 153 | states: List[State], 154 | task_description: str = None, 155 | use_features: bool = True, 156 | use_actions: bool = True, 157 | start_idx: int = 0, 158 | end_idx: int = None, 159 | trim_state: bool = True, 160 | as_list: bool = False 161 | ) -> Union[str, List[str]]: 162 | 163 | res = [] 164 | last_feats = [] 165 | next_step = "" 166 | if task_description: 167 | next_step += task_description + "\n\n" 168 | 169 | for i, step in enumerate(states[start_idx:end_idx]): 170 | next_step += "Step #{}:\n".format(i + 1) 171 | 172 | next_step += step.observation 173 | if trim_state: 174 | if use_features and "you see" not in step.observation.lower(): 175 | new_feats = [] 176 | for feat in sorted(step.state_features, key=lambda x: len(x), reverse=True): 177 | if feat not in last_feats and feat != step.observation \ 178 | and not any(feat.replace(" in your inventory", "") in x for x in new_feats): 179 | new_feats.append(feat) 180 | if len(new_feats) > 0: 181 | next_step += "\nYou also observe:\n\t" 182 | next_step += "\n\t".join(new_feats) 183 | last_feats = step.state_features 184 | elif use_features and "you see" not in step.observation.lower(): 185 | next_step += "\n" + step.state_description 186 | 187 | if use_actions and i + 1 < len(states): 188 | next_step += "\nYou choose to: {}".format(states[i + 1].last_action) 189 | 190 | res.append(next_step) 191 | next_step = "" 192 | 193 | return res if as_list else "\n\n".join(res) 194 | 195 | def similarity(self, trajectories: Union[Trajectory, List[Trajectory]], length_diff_penalty: float = 0.4) -> float: 196 | if isinstance(trajectories, Trajectory): 197 | trajectories = [trajectories] 198 | 199 | state_similarity = 1 200 | action_similarity = 1 201 | for trajectory in trajectories: 202 | assert len(self) == len(trajectory) 203 | 204 | state_similarity = min( 205 | state_similarity, 206 | np.mean([self[i].state_similarity(trajectory[i]) for i in range(len(self))]) 207 | ) 208 | action_similarity = min( 209 | action_similarity, 210 | np.mean([self[i].action_similarity(trajectory[i]) for i in range(1, len(self))]) 211 | ) 212 | 213 | return state_similarity, action_similarity 
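# A minimal usage sketch with made-up step values: build a two-step trajectory,
# render it as prompt text, and take a sub-trajectory. Note that similarity()
# compares equal-length trajectories and returns a (state_similarity,
# action_similarity) pair computed with the configured feature-similarity metric.
if __name__ == "__main__":
    task = "Your task is to boil water."
    traj = Trajectory(task_description=task)
    traj.insert(State(
        observation="You are in the kitchen.",
        state_description="You are in the kitchen. You see a stove and a sink.",
        state_features=["a stove", "a sink"],
        task_description=task,
    ))
    traj.insert(State(
        observation="You move to the stove.",
        state_description="You are at the stove.",
        state_features=["a stove"],
        last_generation="Next action: go stove",
        last_action="go stove",
        reward=0.1,
        task_description=task,
    ))
    print(traj.build_string(use_task=True, use_actions=True))
    sub = traj.slice(0, len(traj))  # deep-copied sub-trajectory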
214 | -------------------------------------------------------------------------------- /sso/skill.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Dict, Any 3 | 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | from sso.llm import query_llm 8 | from sso.trajectory import Trajectory 9 | 10 | 11 | class Skill: 12 | 13 | def __init__(self, instructions: List[str] = None, target: str = None, max_trajectories: int = 5): 14 | self.instructions = instructions 15 | self.target = target 16 | self.max_trajectories = max_trajectories 17 | self.trajectories: List[Trajectory] = [] 18 | self.traj_indices: List[int] = [] 19 | self.start_indices: List[int] = [] 20 | self.end_indices: List[int] = [] 21 | self.prereqs = None 22 | self.state_similarity = 1 23 | self.action_similarity = 1 24 | 25 | def __str__(self): 26 | return "Trajectory Count: {}\nTarget: {}\nInstructions:\n\t{}".format( 27 | len(self.trajectories), 28 | self.target, 29 | "\n\t".join(self.instructions) 30 | ) 31 | 32 | def __eq__(self, other: object): 33 | return isinstance(other, Skill) and \ 34 | set(self.trajectories[-self.max_trajectories:]) == set(other.trajectories[-self.max_trajectories:]) 35 | 36 | def __hash__(self): 37 | return hash(tuple(set(self.trajectories[-self.max_trajectories:]))) 38 | 39 | def traj_len(self): 40 | return min(len(x) for x in self.trajectories) 41 | 42 | def traj_count(self): 43 | return len(self.trajectories) 44 | 45 | def step_count(self): 46 | return sum(len(x) - 1 for x in self.trajectories) 47 | 48 | def reward(self) -> float: 49 | return np.mean([np.sum([x.reward for x in traj]) for traj in self.trajectories]) 50 | 51 | @staticmethod 52 | def from_dict(data: Dict[str, Any]): 53 | skill = Skill() 54 | skill.trajectories = [Trajectory.from_dict(x) for x in data["trajectories"]] 55 | skill.instructions = data["instructions"] 56 | skill.target = data["target"] 57 | skill.prereqs = data["prereqs"] 58 | skill.state_similarity = data["state_similarity"] 59 | skill.action_similarity = data["action_similarity"] 60 | skill.traj_indices = data["traj_indices"] 61 | skill.start_indices = data["start_indices"] 62 | skill.end_indices = data["end_indices"] 63 | return skill 64 | 65 | def to_dict(self) -> Dict[str, Any]: 66 | return dict( 67 | instructions=self.instructions, 68 | target=self.target, 69 | prereqs=self.prereqs, 70 | state_similarity=self.state_similarity, 71 | action_similarity=self.action_similarity, 72 | traj_indices=self.traj_indices, 73 | start_indices=self.start_indices, 74 | end_indices=self.end_indices, 75 | trajectories=[x.to_dict() for x in self.trajectories] 76 | ) 77 | 78 | def info(self) -> Dict[str, Any]: 79 | return dict( 80 | instructions=self.instructions, 81 | target=self.target, 82 | prereqs=self.prereqs, 83 | state_similarity=self.state_similarity, 84 | action_similarity=self.action_similarity, 85 | rewards=[float(np.sum([x.reward for x in traj])) for traj in self.trajectories], 86 | actions=[[step.last_action for step in traj[1:]] for traj in self.trajectories], 87 | traj_indices=self.traj_indices, 88 | start_indices=self.start_indices, 89 | end_indices=self.end_indices, 90 | ) 91 | 92 | def try_add(self, trajectory: Trajectory, traj_idx: int, start_idx: int, end_idx: int) -> bool: 93 | if self.traj_count() <= 1: 94 | self.add(trajectory, traj_idx, start_idx, end_idx) 95 | return True 96 | if trajectory not in self.trajectories: 97 | state_sim, action_sim = 
trajectory.similarity(self.trajectories) 98 | if state_sim >= self.state_similarity and action_sim >= self.action_similarity: 99 | self.add(trajectory, traj_idx, start_idx, end_idx) 100 | return True 101 | return False 102 | 103 | def add(self, trajectory: Trajectory, traj_idx: int, start_idx: int, end_idx: int) -> Skill: 104 | if trajectory not in self.trajectories: 105 | if len(self.trajectories) > 0: 106 | state_sim, action_sim = trajectory.similarity(self.trajectories) 107 | self.state_similarity = min(self.state_similarity, state_sim) 108 | self.action_similarity = min(self.action_similarity, action_sim) 109 | self.trajectories.append(trajectory) 110 | self.traj_indices.append(traj_idx) 111 | self.start_indices.append(start_idx) 112 | self.end_indices.append(end_idx) 113 | self.instructions = None 114 | self.target = None 115 | self.prereqs = None 116 | return self 117 | 118 | def is_compatible(self, other: Skill) -> bool: 119 | shared = [ 120 | (self.traj_indices.index(x), other.traj_indices.index(x)) 121 | for x in self.traj_indices 122 | if x in other.traj_indices 123 | ] 124 | for (this_idx, other_idx) in shared: 125 | if range( 126 | max(self.start_indices[this_idx], other.start_indices[other_idx]), 127 | min(self.end_indices[this_idx], other.end_indices[other_idx]) + 1 128 | ): 129 | return False 130 | return True 131 | 132 | def _generate(self) -> str: 133 | system_message = "You are an expert planning system. You are creating reusable skills to execute when completing various tasks. You create skills by looking at successful examples of task completions. A skill is composed of a list of instructions and a target state. After creating a skill, it will be used to execute actions in an environment. The environment will return a set of observations that summarize the new environment state. These observations will be used in conjunction with the skill's target state to determine whether the last skill was successful." 134 | 135 | summary_prompt = "Consider the example trajectories of states and actions below. You'll be asked to analyze the similarities between each. Pay attention to the wording of the state observations and actions. Then you'll be asked to generate the common instructions, and target state for them." 136 | for t, traj in enumerate(self.trajectories[-self.max_trajectories:]): 137 | summary_prompt += "\n\nExample {}:".format(t + 1) 138 | summary_prompt += "\n\nInitial State: " 139 | summary_prompt += traj[0].state_description 140 | summary_prompt += "\n\nTrajectory:" 141 | for s, step in enumerate(traj[1:]): 142 | summary_prompt += "\nAction {}: {}".format(s, step.last_action) 143 | summary_prompt += "\nObservation {}: {}".format(s, step.observation) 144 | summary_prompt += "\n\nFinal State: " 145 | summary_prompt += traj[-1].state_description 146 | summary_prompt += "\n\nGenerate a summary of what is happening in the examples above and the similarities between them. Provide a name for the skill that is being executed in the examples above. Do not generate skill instructions or target yet." 147 | messages = [dict(role="system", content=system_message), dict(role="user", content=summary_prompt)] 148 | summary_response = query_llm(messages, temperature=0.7) 149 | 150 | instruction_prompt = "Generate a numbered list of instructions for completing the skill. The instructions should be similar to the actions in the examples. Instructions should use the action templates provided below. 
Create generic instructions that would be valid for every example but specific enough to be useful in the examples. Do not mention the examples in the instructions. Use the output format:\nSkill [skill name] instructions:\n1. instruction 1\n2. instruction 2\n..." 151 | instruction_prompt += "\n\nAction templates: {}".format(", ".join(self.trajectories[0][0].all_templates)) 152 | messages += [dict(role="assistant", content=summary_response), dict(role="user", content=instruction_prompt)] 153 | instruction_response = query_llm(messages, temperature=0.7) 154 | 155 | target_prompt = "Generate a single target observation that would indicate the success of the skill. The target should be similar to one of the observations in the final states. Create a generic target that would be valid for every example. Do not mention the examples in the target. Use the output format:\nSkill [skill name] target: [target observation]" 156 | messages += [dict(role="user", content=target_prompt)] 157 | target_response = query_llm(messages, temperature=0.7) 158 | 159 | instructions = [] 160 | for x in instruction_response.lower().split("instructions:")[-1].split("\n"): 161 | if len(x.strip()) > 0: 162 | if x.strip()[0].isnumeric(): 163 | instructions.append(x.lstrip(" \t)-.1234567890")) 164 | elif len(instructions) > 0: 165 | instructions[-1] += " " + x.lstrip(" \t)-.1234567890") 166 | target = target_response.lower().replace("target observation", "").split("target:")[-1].strip(".-:[]\n\t ") 167 | return [], instructions, target 168 | 169 | def build(self, force: bool = False) -> bool: 170 | self.trajectories = [deepcopy(traj) for traj in self.trajectories] 171 | self.trajectories = sorted(self.trajectories, key=lambda x: len(x)) 172 | 173 | if force or self.instructions is None or self.target is None or self.prereqs is None: 174 | self.prereqs, self.instructions, self.target = self._generate() 175 | -------------------------------------------------------------------------------- /sso/env/scienceworld.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Any 2 | import random 3 | from scienceworld import ScienceWorldEnv 4 | 5 | from sso.env import Env 6 | from sso.trajectory import State 7 | from sso.utils import get_similar_action 8 | 9 | 10 | class ScienceWorld(Env): 11 | 12 | ACTION_TEMPLATES = [ 13 | "read OBJ", 14 | "activate OBJ", 15 | "deactivate OBJ", 16 | "open OBJ", 17 | "close OBJ", 18 | "pick up OBJ", 19 | "look in OBJ", 20 | "focus on OBJ", 21 | "move OBJ to CONTAINER", 22 | "pour OBJ in CONTAINER", 23 | "mix CONTAINER", 24 | "teleport LOCATION", 25 | "go LOCATION", 26 | "wait" 27 | ] 28 | 29 | def __init__( 30 | self, 31 | task: str = None, 32 | train_variant_count: int = 10, 33 | test_variant_count: int = 10, 34 | train_variants: List[int] = None, 35 | test_variants: List[int] = None, 36 | max_repeat: int = 10, 37 | seed: int = 42 38 | ): 39 | self.env = ScienceWorldEnv() 40 | random.seed(seed) 41 | assert task is not None, "You must specify a task to load" 42 | self.env.load(task) 43 | 44 | self.train_variants = train_variants 45 | if self.train_variants is None: 46 | self.train_variants = [ 47 | x for x in self.env.getVariationsTrain() 48 | if test_variants is None or x not in test_variants 49 | ] 50 | self.train_variants = random.sample(self.train_variants, min(len(self.train_variants), train_variant_count)) 51 | 52 | self.test_variants = test_variants 53 | if self.test_variants is None: 54 | self.test_variants = [ 55 | x for 
x in self.env.getVariationsTest() 56 | if train_variants is None or x not in train_variants 57 | ] 58 | self.test_variants = random.sample(self.test_variants, min(len(self.test_variants), test_variant_count)) 59 | 60 | self.max_repeat = max_repeat 61 | self.task = task 62 | 63 | self.last_actions = [] 64 | self.t = 0 65 | self._max_steps = None 66 | self.variant = None 67 | self.next_test_idx = 0 68 | self.next_train_idx = 0 69 | self.last_focus = None 70 | self.last_obs = None 71 | 72 | def reset(self, task_id: str = None, test: bool = False, **kwargs) -> Tuple[State, Dict[str, Any]]: 73 | if task_id: 74 | self.variant = task_id.split("_")[1] 75 | self.variant = int(self.variant) 76 | elif test: 77 | self.variant = self.test_variants[self.next_test_idx] 78 | self.next_test_idx += 1 79 | if self.next_test_idx >= len(self.test_variants): 80 | self.next_test_idx = 0 81 | else: 82 | self.variant = self.train_variants[self.next_train_idx] 83 | self.next_train_idx += 1 84 | if self.next_train_idx >= len(self.train_variants): 85 | self.next_train_idx = 0 86 | self.env.load(self.task, self.variant, "easy", generateGoldPath=True) 87 | 88 | self.last_obs, info = self.env.reset(**kwargs) 89 | self.last_actions = [] 90 | self.t = 0 91 | self._max_steps = int(len(self.env.getGoldActionSequence()) * 1.5) 92 | self.last_focus = None 93 | info["invalid"] = "no known action matches that input" in self.last_obs.lower() 94 | info["obs"] = self.last_obs 95 | info["task_id"] = "{}_{}".format(self.task, self.variant) 96 | info["task_description"] = info["taskDesc"] 97 | info["success"] = False 98 | state = self.get_state(info, None, None, 0, False) 99 | return state, info 100 | 101 | def step(self, action: str, **kwargs) -> Tuple[State, float, bool, Dict[str, Any]]: 102 | parsed_action = self.parse_action(action) 103 | self.last_obs, reward, done, info = self.env.step(parsed_action, **kwargs) 104 | if "focus" in self.last_obs.lower(): 105 | self.last_focus = self.last_obs 106 | self.last_actions.append(parsed_action) 107 | info["obs"] = self.last_obs 108 | info["task_id"] = "{}_{}".format(self.task, self.variant) 109 | info["task_description"] = info["taskDesc"] 110 | info["success"] = info["score"] >= 100 111 | info["last_raw_action"] = action 112 | self.t += 1 113 | if self.t >= self.max_steps or len(self.last_actions) > self.max_repeat and \ 114 | all(x == self.last_actions[-1] for x in self.last_actions[-self.max_repeat:]): 115 | done = True 116 | reward /= 100 117 | state = self.get_state(info, action, parsed_action, reward, done) 118 | return state, reward, done, info 119 | 120 | def parse_action(self, action: str) -> str: 121 | action = action.lower() 122 | if "next action:" in action: 123 | action = action.split("next action:")[-1].strip().split("\n")[0].strip() 124 | parsed_action = None 125 | if "ambiguous request" in self.last_obs.lower(): 126 | nums = [c for c in action if c.isnumeric()] 127 | if len(nums) > 0: 128 | parsed_action = nums[0] 129 | if parsed_action is None: 130 | parsed_action = get_similar_action(action, self.env.getValidActionObjectCombinations()) 131 | if parsed_action is None: 132 | parsed_action = action 133 | return parsed_action 134 | 135 | def get_action_prompt(self): 136 | res = "Generate the action using one of the following templates, where OBJ is an object in the scene and LOCATION is an adjacent room and CONTAINER is an object that can contain or hold other objects such as a pot, table, or inventory." 
137 | res += " The \"focus on OBJ\" action is extremely critical and should not be used to look at or shift your attention to a specific object. It should only be used as described in the task description. Using this action inappropriately will result in task failure." 138 | res += " The \"wait\" action is used to wait for time to pass during tasks that require the passage of time." 139 | res += "\nTemplates:\n\t" + "\n\t".join(self.ACTION_TEMPLATES) 140 | return res 141 | 142 | def get_state(self, state: Dict[str, Any], last_generation: str, last_action: str, reward: float, done: bool) -> State: 143 | return State( 144 | observation=state["obs"][:state["obs"].index("In it, you see")].strip() if "In it, you see" in state["obs"] else state["obs"], 145 | state_features=self.get_features(state, self.last_focus), 146 | state_description=self.get_description(state), 147 | last_generation=last_generation, 148 | last_action=last_action, 149 | reward=reward, 150 | done=done, 151 | task_description=state["taskDesc"], 152 | action_prompt=self.get_action_prompt(), 153 | all_templates=self.ACTION_TEMPLATES, 154 | ) 155 | 156 | @property 157 | def max_steps(self) -> int: 158 | return self._max_steps 159 | 160 | @property 161 | def num_train(self) -> int: 162 | return len(self.train_variants) 163 | 164 | @property 165 | def num_test(self) -> int: 166 | return len(self.test_variants) 167 | 168 | @property 169 | def train_ids(self) -> Tuple[str]: 170 | return tuple(["{}_{}".format(self.task, v) for v in self.train_variants]) 171 | 172 | @property 173 | def test_ids(self) -> Tuple[str]: 174 | return tuple(["{}_{}".format(self.task, v) for v in self.test_variants]) 175 | 176 | @staticmethod 177 | def parse_commas(sub_feats: str) -> List[str]: 178 | p = [] 179 | idx = 0 180 | next_feat = "" 181 | while idx < len(sub_feats): 182 | if sub_feats[idx] == ",": 183 | p.append(next_feat) 184 | next_feat = "" 185 | idx += 1 186 | elif sub_feats[idx:].startswith("(containing") and ")" in sub_feats[idx:]: 187 | p += ScienceWorld.parse_commas( 188 | sub_feats[idx + len("(containing") : idx + sub_feats[idx:].index(")")] 189 | ) 190 | paren_size = sub_feats[idx:].index(")") + 1 191 | next_feat += sub_feats[idx:idx+paren_size] 192 | p.append(next_feat) 193 | next_feat = "" 194 | idx += paren_size 195 | else: 196 | next_feat += sub_feats[idx] 197 | idx += 1 198 | p.append(next_feat) 199 | return p 200 | 201 | @staticmethod 202 | def parse_feature(feature: str) -> List[str]: 203 | parsed = [feature] 204 | if ":" in feature and ": which is" not in feature: 205 | parsed += ScienceWorld.parse_commas(feature[feature.index(":") + 1:]) 206 | elif "(containing" in feature: 207 | parsed += ScienceWorld.parse_commas(feature[feature.index("(containing") + len("(containing") : feature.index(")")]) 208 | return [p.strip(" \n\t.,").lower() for p in parsed if p.strip(" \n\t.,") not in ["", "nothing"]] 209 | 210 | @staticmethod 211 | def get_description(state: Dict[str, Any]) -> str: 212 | res = state["obs"] 213 | if "you see" not in res.lower(): 214 | res += "\n" + state["look"] 215 | if state["inv"] not in res: 216 | res += "\n" + state["inv"] 217 | return res 218 | 219 | @staticmethod 220 | def get_features(state: Dict[str, Any], last_focus=None) -> List[str]: 221 | features = [] 222 | if last_focus is not None: 223 | features.append(last_focus) 224 | 225 | for feat in state["look"].split("\n"): 226 | if feat.strip() != "" and feat != "You also see:": 227 | if feat.startswith("\t"): 228 | features.extend([x for x in 
ScienceWorld.parse_feature(feat)]) 229 | else: 230 | features.append(feat.replace("In it, you see:", "").strip().lower()) 231 | 232 | for feat in state["inv"].split("\n"): 233 | if feat.strip() != "" and feat != "In your inventory, you see:": 234 | features.extend([x + " in your inventory" for x in ScienceWorld.parse_feature(feat)]) 235 | 236 | return [x.replace("\n", ", ").replace("\t", " ") for x in features] 237 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, List 2 | from argparse import ArgumentParser 3 | from tqdm import tqdm 4 | import os 5 | import json 6 | import numpy as np 7 | 8 | from sso.env import Env 9 | from sso.env.scienceworld import ScienceWorld 10 | from sso.env.nethack.base import NetHackTask 11 | from sso.agent import Agent 12 | from sso.agent.skills import SkillsAgent 13 | from sso.agent.fewshot import FewshotAgent 14 | from sso.agent.reflexion import ReflexionAgent 15 | from sso.memory.skillset import SkillSetMemory 16 | from sso.memory.examples import ExamplesMemory 17 | from sso.trajectory import Trajectory 18 | from sso.llm import set_default_model 19 | 20 | 21 | def run_episode( 22 | env: Env, 23 | agent: Agent, 24 | task_id: str = None, 25 | test: bool = False, 26 | ) -> Tuple[str, Trajectory, bool, float]: 27 | 28 | done = False 29 | state, info = env.reset(task_id=task_id, test=test) 30 | trajectory = Trajectory(task_description=info["task_description"]) 31 | trajectory.insert(state) 32 | 33 | agent.reset_log() 34 | agent.log("task_id", info["task_id"]) 35 | agent.log("task_description", info["task_description"]) 36 | 37 | pbar = tqdm(total=env.max_steps) 38 | score = 0 39 | while not done: 40 | action = agent.act(trajectory) 41 | state, reward, done, info = env.step(action) 42 | if reward > 0: 43 | score += reward 44 | trajectory.insert(state) 45 | pbar.update(1) 46 | pbar.close() 47 | 48 | agent.record_done(trajectory) 49 | agent.log("trajectory_success", info["success"]) 50 | agent.log("score", score) 51 | agent.log("length", len(trajectory)) 52 | 53 | return info["task_id"], trajectory, info["success"], score 54 | 55 | 56 | def log_results( 57 | agent: Agent, 58 | trajectory: Trajectory, 59 | save_path: str, 60 | iteration: int, 61 | task_id: str, 62 | success: bool, 63 | score: float 64 | ): 65 | os.makedirs(save_path, exist_ok=True) 66 | with open(os.path.join(save_path, "logs.json"), "w") as f: 67 | json.dump(agent.get_log(), f, indent=4) 68 | with open(os.path.join(save_path, "trajectory.json"), "w") as f: 69 | json.dump(trajectory.to_dict(), f, indent=4) 70 | agent.save(save_path) 71 | print("Iter: {}, task: {}, success: {}, score: {}".format(iteration, task_id, success, score)) 72 | 73 | 74 | def run_experiment( 75 | env: Env, 76 | agent: Agent, 77 | results_dir: str, 78 | experiment_name: str, 79 | iteration: int, 80 | temp: float = 1, 81 | train: bool = True, 82 | use_test_tasks: bool = False 83 | ): 84 | agent.train(train) 85 | set_default_model(temp=temp) 86 | task_id, trajectory, success, score = run_episode(env, agent, test=use_test_tasks) 87 | save_path = os.path.join(results_dir, experiment_name, str(iteration)) 88 | log_results(agent, trajectory, save_path, iteration, task_id, success, score) 89 | return score, success 90 | 91 | 92 | def run( 93 | agent: Agent, 94 | env: Env, 95 | results_dir: str = "results", 96 | train_count: int = 5, 97 | adapt_count: int = 5, 98 | test_count: 
int = 1, 99 | test_freq: int = 0, 100 | train_temp: float = 1, 101 | test_temp: float = 0, 102 | test_init: bool = False, 103 | test_adapt: bool = False 104 | ) -> Tuple[Dict[str, List[bool]], Dict[str, List[bool]]]: 105 | 106 | def get_scores(experiment_name): 107 | s = [] 108 | for exp in os.listdir(os.path.join(results_dir, experiment_name)): 109 | with open(os.path.join(results_dir, experiment_name, exp, "logs.json"), "r") as f: 110 | s.append(json.load(f)[-1]["score"]) 111 | return s 112 | results = dict() 113 | 114 | # Test base agent 115 | if test_init and (train_count == 0 or train_count > 0): 116 | print("\n##############################\nTesting base agent") 117 | for test_iter in range(test_count * env.num_test): 118 | run_experiment(env, agent, results_dir, "test_init", test_iter, temp=test_temp, train=False, use_test_tasks=True) 119 | if test_count > 0: 120 | results["base"] = np.mean(get_scores("test_init")) 121 | with open(os.path.join(results_dir, "results.json"), "w") as f: 122 | json.dump(results, f, indent=4) 123 | 124 | # Adapt 125 | if test_adapt: 126 | print("\n##############################\nTesting base adaptation") 127 | for test_id in env.test_ids: 128 | results["adapt_" + test_id] = [] 129 | agent.clear() 130 | for train_iter in range(adapt_count): 131 | score, _ = run_experiment(env, agent, results_dir, "test_adapt/{}/{}".format(test_id, train_iter), 132 | "train", temp=test_temp, train=True, use_test_tasks=True) 133 | results["adapt_" + test_id].append(score) 134 | results["adapt_best"] = np.mean([np.max(results["adapt_" + test_id]) for test_id in env.test_ids]) 135 | with open(os.path.join(results_dir, "results.json"), "w") as f: 136 | json.dump(results, f, indent=4) 137 | agent.clear() 138 | 139 | # Train agent 140 | print("\n##############################\nTraining agent") 141 | train_scores = [] 142 | for train_iter in range(train_count * env.num_train): 143 | score = run_experiment(env, agent, results_dir, "train", train_iter, temp=train_temp) 144 | train_scores.append(score) 145 | if test_freq > 0 and (train_iter + 1) % test_freq == 0 and train_iter < (train_count * env.num_train) - 1: 146 | for test_iter in range(test_count * env.num_test): 147 | run_experiment(env, agent, results_dir, "test_iter{}".format(train_iter), test_iter, temp=test_temp, 148 | train=False, use_test_tasks=True) 149 | results["test_iter{}".format(train_iter)] = np.mean(get_scores("test_iter{}".format(train_iter))) 150 | with open(os.path.join(results_dir, "results.json"), "w") as f: 151 | json.dump(results, f, indent=4) 152 | 153 | # Transfer 154 | print("\n##############################\nTesting transfer agent") 155 | if train_count * env.num_train > env.num_test or not os.path.exists(os.path.join(results_dir, "test_transfer")): 156 | for test_iter in range(test_count * env.num_test): 157 | run_experiment(env, agent, results_dir, "test_transfer", test_iter, temp=test_temp, train=False, use_test_tasks=True) 158 | if test_count > 0: 159 | results["transfer"] = np.mean(get_scores("test_transfer")) 160 | with open(os.path.join(results_dir, "results.json"), "w") as f: 161 | json.dump(results, f, indent=4) 162 | 163 | 164 | if __name__ == '__main__': 165 | 166 | parser = ArgumentParser() 167 | 168 | # Experiment params 169 | parser.add_argument("--output", type=str, default="results", help="output directory") 170 | parser.add_argument("--train_iters", type=int, default=3, help="number of iterations to run") 171 | parser.add_argument("--test_iters", type=int, default=1, help="number of 
test iterations to run") 172 | parser.add_argument("--test_freq", type=int, default=0, help="test frequency, if 0 will test once at end") 173 | parser.add_argument("--test_init", action="store_true", help="test initial agent") 174 | parser.add_argument("--test_adapt", action="store_true", help="test adapting to test tasks") 175 | 176 | # Agent params 177 | parser.add_argument("--agent", type=str, default="skills", help="agent type") 178 | parser.add_argument("--load", type=str, default=None, help="directory to load agent from") 179 | parser.add_argument("--model", type=str, default="gpt-4-0613", help="model name") 180 | parser.add_argument("--train_temp", type=float, default=0.7, help="Generation temperature for the llm during training") 181 | parser.add_argument("--test_temp", type=float, default=0, help="Generation temperature for the llm during testing") 182 | parser.add_argument("--max_history", type=int, default=10, help="number of past steps to keep in history") 183 | parser.add_argument("--full_states", action="store_true", help="do not trim states to only keep new information") 184 | parser.add_argument("--similarity_metric", type=str, default="text-embedding-ada-002", help="similarity metric to use, iou or model name") 185 | 186 | # Env params 187 | parser.add_argument("--env", type=str, default="nethack", help="environment type") 188 | parser.add_argument("--task", type=str, default="MiniHack-KeyLavaCross-v0", help="task to run") 189 | parser.add_argument("--train_variant_count", type=int, default=10, help="number of variants to train on") 190 | parser.add_argument("--test_variant_count", type=int, default=10, help="number of variants to test on") 191 | parser.add_argument("--train_variants", type=int, nargs="+", default=None, help="train task variants, if None will use default task split") 192 | parser.add_argument("--test_variants", type=int, nargs="+", default=None, help="test task variants, if None will use default task split") 193 | 194 | # Memory params 195 | parser.add_argument("--memory", type=str, default="skills", help="memory type") 196 | parser.add_argument("--reward_weight", type=float, default=0.1, help="weight for trajectory reward in skill score") 197 | parser.add_argument("--state_weight", type=float, default=1, help="weight for state similarity in skill score") 198 | parser.add_argument("--action_weight", type=float, default=1, help="weight for action similarity in skill score") 199 | parser.add_argument("--coverage_weight", type=float, default=0.01, help="weight for task coverage in skill score") 200 | 201 | args = parser.parse_args() 202 | 203 | # Set LLMs 204 | set_default_model(model=args.model, temp=args.train_temp, 205 | embedding=None if args.similarity_metric == "iou" else args.similarity_metric) 206 | 207 | # Set memory 208 | if args.memory == "skills": 209 | memory = SkillSetMemory( 210 | reward_weight=args.reward_weight, 211 | state_weight=args.state_weight, 212 | action_weight=args.action_weight, 213 | coverage_weight=args.coverage_weight, 214 | ) 215 | elif args.memory == "examples": 216 | memory = ExamplesMemory() 217 | else: 218 | raise ValueError(f"Unknown memory type: {args.memory}") 219 | 220 | # Set agent 221 | agent_args = dict( 222 | memory=memory, 223 | max_history=args.max_history, 224 | trim_states=not args.full_states 225 | ) 226 | if args.agent == "skills": 227 | agent = SkillsAgent(**agent_args) 228 | elif args.agent == "fewshot": 229 | agent = FewshotAgent(**agent_args) 230 | elif args.agent == "reflexion": 231 | agent = 
ReflexionAgent(**agent_args) 232 | else: 233 | raise ValueError("Invalid agent: {}".format(args.agent)) 234 | if args.load is not None: 235 | agent.load(args.load) 236 | 237 | # Set env 238 | env_args = dict( 239 | task=args.task, 240 | train_variant_count=args.train_variant_count, 241 | test_variant_count=args.test_variant_count, 242 | train_variants=args.train_variants, 243 | test_variants=args.test_variants 244 | ) 245 | if args.env == "scienceworld": 246 | env = ScienceWorld(**env_args) 247 | elif args.env == "nethack": 248 | env = NetHackTask(**env_args) 249 | else: 250 | raise ValueError("Invalid env: {}".format(args.env)) 251 | 252 | # Run 253 | os.makedirs(args.output, exist_ok=True) 254 | with open(os.path.join(args.output, "args.json"), "w") as f: 255 | json.dump(vars(args), f, indent=4) 256 | run(agent, env, results_dir=args.output, train_count=args.train_iters, test_count=args.test_iters, 257 | test_freq=args.test_freq, train_temp=args.train_temp, test_temp=args.test_temp, 258 | test_init=args.test_init, test_adapt=args.test_adapt) 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | SSO - Allen Institute for AI 10 | 11 | 13 | 14 | 15 | 16 | 17 | 18 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
33 |
34 |
35 |
36 |
37 |

Skill Set Optimization: Reinforcing Language Model Behavior via Transferable Skills

38 |
39 | 40 | Kolby Nottingham1, 41 | 42 | 43 | Bodhisattwa Prasad Majumder2*, 44 | 45 | 46 | Bhavana Dalvi Mishra2*, 47 | 48 |
49 | 50 | Sameer Singh1, 51 | 52 | 53 | Peter Clark2, 54 | 55 | 56 | Roy Fox1 57 |
59 | 60 |
61 | 1University of California Irvine, 62 | 2Allen Institute for AI 63 |
64 | *Equal Contribution 65 |
66 | 67 |
68 | 118 | 119 |
120 |
121 |
122 |
123 |
124 |
125 | 126 |
127 |
128 |
129 | 130 |
131 |

132 | Continual learning for LLM actors via discovering and reinforcing in-context skills 133 |

134 |
135 |
136 |
137 | 138 | 139 |
140 |
141 | 142 |
143 |
144 |

Abstract

145 |
146 |

147 | Large language models (LLMs) have recently been used for sequential decision making in interactive environments. However, leveraging environment reward signals for continual LLM actor improvement is not straightforward. We propose Skill Set Optimization (SSO) for improving LLM actor performance through constructing and refining sets of transferable skills. SSO constructs skills by extracting common subtrajectories with high rewards and generating subgoals and instructions to represent each skill. These skills are provided to the LLM actor in-context to reinforce behaviors with high rewards. Then, SSO further refines the skill set by pruning skills that do not continue to result in high rewards. We evaluate our method in the classic videogame NetHack and the text environment ScienceWorld to demonstrate SSO's ability to optimize a set of skills and perform in-context policy improvement. SSO outperforms baselines by 40% in our custom NetHack task and outperforms the previous state-of-the-art in ScienceWorld by 35%. 148 |

149 |

150 |

151 |
152 |
153 |
154 | 155 | 156 |
157 |
158 |
159 |
160 | 161 |

In-Context Policy Improvement

162 | 163 |
164 |

165 | Like other continual learning methods*, SSO uses in-context "memories" containing information about the task and environment to improve the LLM actor's policy. The memories that SSO generates are instructions for achieving subgoals, which we call skills. Unlike previous work, SSO continuously evaluates generated memories, creates memories that define modular subgoals, and facilitates memory retrieval (see the sketch below). 166 |

167 |

168 | * e.g. Voyager, ExpeL, and CLIN agents 169 |
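To make the retrieval step concrete, here is a minimal Python sketch of how retrieved skills can be provided in-context to the actor. It reuses query_llm from sso/llm and the skill fields (prereqs, target, instructions) used in sso/memory/skillset.py, but the act_with_skills helper and the prompt wording are illustrative assumptions rather than the repo's actual agent code (see sso/agent/skills.py).

from sso.llm import query_llm

def act_with_skills(memory, trajectory, observation):
    # Assumes set_default_model(...) has already been called to choose the LLM.
    # Retrieve the skills most relevant to the current trajectory.
    skills = memory.get_memories(trajectory)

    # Surface each retrieved skill in the system prompt so the actor can reinforce it.
    system = "You are an expert agent. Use the following skills when their prerequisites hold."
    for i, skill in enumerate(skills, start=1):
        system += "\nSkill {} prerequisites: {}".format(i, ", ".join(skill.prereqs))
        system += "\nSkill {} target: {}".format(i, skill.target)
        system += "\nSkill {} instructions: {}".format(i, ", ".join(skill.instructions))

    messages = [
        dict(role="system", content=system),
        dict(role="user", content="Observation: {}\nNext action:".format(observation)),
    ]
    return query_llm(messages, temperature=0)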

170 | 171 |

Skill Set Optimization

172 | 173 |
174 |

175 | Each iteration of SSO includes: 176 |

177 |
    178 |
  1. Rolling out a single trajectory with the LLM actor and current skill set 179 |
  2. Constructing new skills 180 |
  3. Refining executed skills 181 |
182 |

183 | To construct new skills, we extract potential subtrajectories, score them using discounted reward, state and action similarity, and length, sample an updated skill set using beam search, and generate subgoals and instructions for each new skill. We refine the constructed skill set by filtering out skills that did not result in high rewards when used in previous trajectories. Then, when providing skills in-context, we retrieve only the most relevant skills based on cosine similarity between skill initial states and the current environment state. A simplified sketch of one iteration is shown below. 184 |
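The following is a minimal Python sketch of one SSO iteration built around SkillSetMemory from sso/memory/skillset.py. The env, actor, rollout, and num_iterations names are placeholders for the episode loop (presumably driven by main.py and sso/agent/skills.py, not shown here); only insert, build, and get_memories are actual SkillSetMemory methods.

from sso.memory.skillset import SkillSetMemory

def sso_training_loop(env, actor, rollout, num_iterations=10):
    # Sketch only: rollout(env, actor, memory) is a hypothetical helper that plays
    # one episode, retrieving relevant skills via memory.get_memories(...) and
    # logging executed skills via memory.log_sampled(...).
    memory = SkillSetMemory()
    for _ in range(num_iterations):
        # 1. Roll out a single trajectory with the LLM actor and the current skill set
        trajectory = rollout(env, actor, memory)

        # 2. Construct new skills: extract candidate subtrajectories, score them by
        #    discounted reward, similarity, and length, and beam-search an updated skill set
        memory.insert(trajectory)

        # 3. Refine the skill set: generate subgoals/instructions for selected skills and
        #    drop semantic duplicates; skills with poor reward feedback are penalized in
        #    later beam searches
        memory.build()
    return memory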

185 | 186 |

Skill Lifecycle

187 | 188 |
189 |
190 |

191 | Each row of this plot shows all of the skills created in the corresponding iteration and when they were executed. On both ScienceWorld and NetHack, SSO prunes most new skills after a few iterations. The LLM actor relies on more recent skills as it continues to improve at the task, learning new skills and refining old ones. 192 |

193 | 194 |

State-of-the-art Results

195 | 196 |
197 |

198 | SSO outperforms the previous state-of-the-art in ScienceWorld by 35% in task adaptation and 14% in task transfer. Learned and reinforced skills, such as those listed below, provide knowledge of subgoals that is transferable across tasks. 199 |

200 | 201 | 202 | 209 | 218 | 219 |
203 | You move to the kitchen 204 |
    205 |
  1. Go to the hallway 206 |
  2. Go to the kitchen 207 |
208 |
210 | The stove is turned on. on the stove is: a substance called liquid [substance] 211 |
    212 |
  1. focus on the thermometer 213 |
  2. focus on the substance you want to heat 214 |
  3. move the focused substance to the stove 215 |
  4. activate the stove 216 |
217 |
220 |
221 |
222 |
223 |
224 | 225 | 226 | 227 |
228 |
229 |

BibTeX

230 |
@article{nottingham2024sso,
231 |   author    = "Nottingham, Kolby and Majumder, Bodhisattwa Prasad and Dalvi Mishra, Bhavana and Singh, Sameer and Clark, Peter and Fox, Roy",
232 |   title     = "Skill Set Optimization: Reinforcing Language Model Behavior via Transferable Skills",
233 |   journal   = "arXiv",
234 |   year      = "2024",
235 |   url       = "https://arxiv.org/abs/2402.03244"
236 | }
237 |
238 |
239 | 240 | 241 | 262 | 263 | 264 | -------------------------------------------------------------------------------- /sso/memory/skillset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Tuple 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | from copy import deepcopy 6 | from functools import lru_cache 7 | 8 | from sso.trajectory import Trajectory 9 | from sso.skill import Skill 10 | from sso.memory import Memory 11 | from sso.llm import query_llm 12 | from sso.utils import get_state_similarity 13 | 14 | 15 | class SkillSetMemory(Memory): 16 | 17 | def __init__( 18 | self, 19 | max_trajectories: int = 10, 20 | num_beams: int = 20, 21 | max_skills: int = 5000, 22 | max_traj_len: int = 50, 23 | min_skill_len: int = 3, # these numbers include the start and end states 24 | max_skill_len: int = 6, 25 | max_traj_count: int = 2, 26 | coverage_weight: float = .01, 27 | reward_weight: float = .01, 28 | state_weight: float = .1, 29 | action_weight: float = 1, 30 | sampled_weight: float = 1, 31 | **kwargs 32 | ): 33 | super().__init__(max_trajectories=max_trajectories, **kwargs) 34 | self.num_beams = num_beams 35 | self.max_traj_len = max_traj_len 36 | self.max_skills = max_skills 37 | self.min_skill_len = min_skill_len 38 | self.max_skill_len = max_skill_len 39 | self.max_traj_count = max_traj_count 40 | 41 | self.coverage_weight = coverage_weight 42 | self.reward_weight = reward_weight 43 | self.state_weight = state_weight 44 | self.action_weight = action_weight 45 | self.sampled_weight = sampled_weight 46 | 47 | self.last_skills = [] 48 | self.built_trajectories = set() 49 | self.mean_coverage = np.mean(np.arange(self.min_skill_len, self.max_skill_len + 1)) 50 | self.std_coverage = np.std(np.arange(self.min_skill_len, self.max_skill_len + 1)) 51 | self.mean_reward = 0 52 | self.std_reward = 1 53 | self.all_state_similarities = [] 54 | self.mean_state_similarity = 0 55 | self.std_state_similarity = 1 56 | self.all_action_similarities = [] 57 | self.mean_action_similarity = 0 58 | self.std_action_similarity = 1 59 | 60 | def insert(self, trajectory: Trajectory) -> None: 61 | 62 | # Split trajectories into successful and unsuccessful 63 | inserted = None 64 | idx = 0 65 | while self.max_traj_len is not None and idx + self.max_traj_len < len(trajectory): 66 | sub_traj = trajectory.slice(idx, idx + self.max_traj_len) 67 | last_reward = [i for i, x in enumerate(sub_traj) if x.reward > 0] 68 | if len(last_reward) > 0 and last_reward[-1] >= self.min_skill_len: 69 | self.trajectories.append(sub_traj.slice(0, last_reward[-1] + 1)) 70 | idx += last_reward[-1] 71 | idx += 1 72 | last_reward = [i for i, x in enumerate(trajectory) if x.reward > 0] 73 | if len(last_reward) > 0 and last_reward[-1] - idx >= self.min_skill_len: 74 | self.trajectories.append(trajectory.slice(idx, last_reward[-1] + 1)) 75 | inserted = trajectory.slice(0, last_reward[-1] + 1) 76 | 77 | # Record skill feedback 78 | successful_cutoff = len(inserted) if inserted is not None else 0 79 | self.skill_feedback.append([]) 80 | for t, skill in self.sampled: 81 | ret = sum(step.reward * self.discount ** i for i, step in enumerate(trajectory[t:])) 82 | self.skill_feedback[-1].append((skill, ret, t < successful_cutoff)) 83 | while len(self.skill_feedback) > self.max_trajectories * self.skill_memory_scale: 84 | self.skill_feedback.pop(0) 85 | self.sampled = [] 86 | 87 | if len(self.trajectories) < 2: 88 | return 89 | 90 | # Extract potential skills 91 | skills: 
List[Skill] = [] 92 | for trajectory_idx, trajectory in enumerate(self.trajectories[-self.max_trajectories:]): 93 | if trajectory in self.built_trajectories: 94 | continue 95 | self.built_trajectories.add(trajectory) 96 | if trajectory_idx == 0: 97 | continue 98 | trajectory_idx += max(0, len(self.trajectories) - self.max_trajectories) 99 | for start_idx in range(len(trajectory)): 100 | for end_idx in range(start_idx + self.min_skill_len, start_idx + self.max_skill_len): 101 | if end_idx > len(trajectory) + 1: 102 | break 103 | 104 | skill = Skill(max_trajectories=self.max_traj_count) 105 | skill.add(trajectory.slice(start_idx, end_idx), trajectory_idx, start_idx, end_idx) 106 | 107 | for i, traj1 in enumerate(self.trajectories[-self.max_trajectories:trajectory_idx]): 108 | i += max(0, len(self.trajectories) - self.max_trajectories) 109 | proposed_skill = self.add_to_skill(skill, traj1, i) 110 | if proposed_skill is not None: 111 | skills.append(proposed_skill) 112 | self.all_state_similarities.append(proposed_skill.state_similarity) 113 | self.all_action_similarities.append(proposed_skill.action_similarity) 114 | 115 | if len(skills) == 0: 116 | print("Extracted no skills.") 117 | return 118 | 119 | # Set similarity normalization 120 | self.mean_state_similarity = np.mean(self.all_state_similarities) 121 | self.std_state_similarity = np.std(self.all_state_similarities) 122 | self.mean_action_similarity = np.mean(self.all_action_similarities) 123 | self.std_action_similarity = np.std(self.all_action_similarities) 124 | skills = [x for x in skills if x.state_similarity > self.mean_state_similarity and x.action_similarity > self.mean_action_similarity] 125 | 126 | # Set reward normalization 127 | rewards = [x.reward for traj in self.trajectories[-self.max_trajectories:] for x in traj if x.reward > 0] 128 | self.mean_reward = np.mean(rewards) 129 | self.std_reward = np.std(rewards) 130 | 131 | # Sort skills 132 | self.last_skills = list(set(self.last_skills + skills)) 133 | self.last_skills = sorted(self.last_skills, key=lambda x: self.score_skill(x), reverse=True) 134 | 135 | # Only keep skills used in beam search 136 | self.memory, sampled_skills = self.beam_search(self.last_skills, return_sampled=True) 137 | self.last_skills = [skill for skill in self.last_skills if skill in sampled_skills] 138 | self.last_skills = sorted(self.last_skills, key=lambda x: sampled_skills[x], reverse=True)[:self.max_skills] 139 | 140 | def get_memories(self, trajectory: Trajectory, n: int = None) -> List[Skill]: 141 | retrieved_skills = [] 142 | 143 | unused = self.memory.copy() 144 | for state in trajectory: 145 | for skill in unused: 146 | if state.skill_target is not None and state.skill_target == skill.target: 147 | retrieved_skills.append(skill) 148 | unused.remove(skill) 149 | break 150 | 151 | unused = sorted( 152 | unused, 153 | key=lambda x: np.mean([ 154 | get_state_similarity(subtraj[0], trajectory[-1], init_state=True) 155 | for subtraj in x.trajectories 156 | ]) 157 | ) 158 | 159 | retrieved_skills += unused[-self.max_retrieval:] 160 | 161 | return list(set(retrieved_skills)) 162 | 163 | def build(self, trajectories: Union[Trajectory, List[Trajectory]] = [], resample: bool = False) -> None: 164 | super().build(trajectories) 165 | 166 | if resample: 167 | self.memory = self.beam_search(self.last_skills, return_sampled=False) 168 | 169 | print("Building selected skill set...") 170 | proposed_skills = self.memory.copy() 171 | self.memory = [] 172 | for skill in tqdm(proposed_skills): 173 | 
skill.build() 174 | if not self.check_similar_skills(skill): 175 | self.memory.append(skill) 176 | 177 | def log_sampled(self, step: int, skill: Skill) -> None: 178 | if not any(skill == x for _, x in self.sampled): 179 | self.sampled.append((step, skill)) 180 | 181 | @staticmethod 182 | @lru_cache(maxsize=1000) 183 | def _check_similar_skills(memory: Tuple[Skill], skill: Skill) -> bool: 184 | if len(memory) == 0: 185 | return False 186 | system_message = "You are an expert planning system capable of completing various tasks in this environment. The environment provides observations in response to your actions. You have a library of skills that you reference to execute actions. A skill is composed of a list of instructions, a list of prerequisites, and a target state." 187 | 188 | prompt = "Given the below list of existing skills and the new skill, determine which existing skills are semantically equivalent to the new skill. Output a comma delimited list of numbers that correspond to the existing skills. If no skills are equivalent output 'None'. Use the following format for your response:" 189 | prompt += "\nThe existing skills permit the agent to... The new skill will permit the agent to...\nEquivalent skills: 1, 2, 3" 190 | prompt += "\n\nExisting Skills:" 191 | for i, existing_skill in enumerate(memory): 192 | prompt += "\nSkill {} prerequisites: {}".format(i + 1, ", ".join(existing_skill.prereqs)) 193 | prompt += "\nSkill {} target: {}".format(i + 1, existing_skill.target) 194 | prompt += "\nSkill {} instructions: {}".format(i + 1, ", ".join(existing_skill.instructions)) 195 | prompt += "\n\nNew skill prerequisites: {}".format(", ".join(skill.prereqs)) 196 | prompt += "\nNew skill target: {}".format(skill.target) 197 | prompt += "\nNew skill instructions: {}".format(", ".join(skill.instructions)) 198 | 199 | messages = [dict(role="system", content=system_message), dict(role="user", content=prompt)] 200 | response = query_llm(messages, temperature=0).lower() 201 | return "none" not in response.split("skills:")[-1].split("skill:")[-1] 202 | 203 | def check_similar_skills(self, skill: Skill) -> bool: 204 | return SkillSetMemory._check_similar_skills(tuple(self.memory), skill) 205 | 206 | def _norm_score(self, score: float, mean: float, std: float) -> float: 207 | return max(0, (score - mean + 2 * std) / std) 208 | 209 | def score_skill(self, skill: Skill) -> float: 210 | score = 0 211 | if self.coverage_weight > 0: 212 | score += self.coverage_weight * self._norm_score(skill.step_count(), self.mean_coverage, self.std_coverage) 213 | if self.reward_weight > 0: 214 | score += self.reward_weight * self._norm_score(skill.reward(), self.mean_reward, self.std_reward) 215 | if self.state_weight > 0: 216 | score += self.state_weight * self._norm_score(skill.state_similarity, self.mean_state_similarity, self.std_state_similarity) 217 | if self.action_weight > 0: 218 | score += self.action_weight * self._norm_score(skill.action_similarity, self.mean_action_similarity, self.std_action_similarity) 219 | if self.sampled_weight > 0: 220 | for episode in self.skill_feedback: 221 | for skl, ret, suc in episode: 222 | if skl == skill: 223 | score += (1 if suc else -1) * self.sampled_weight * self._norm_score(ret, self.mean_reward, self.std_reward) 224 | score /= (self.coverage_weight + self.reward_weight + self.state_weight + self.action_weight + self.sampled_weight) 225 | return score 226 | 227 | def add_to_skill(self, skill: Skill, trajectory: Trajectory, traj_idx: int) -> Skill: 228 | best_skill = 
None 229 | for t1 in range(len(trajectory)): 230 | for t2 in range(t1 + skill.traj_len(), 231 | t1 + skill.traj_len() + 1): 232 | if t2 <= len(trajectory) and t2 - t1 >= self.min_skill_len: 233 | new_skill = deepcopy(skill) 234 | if new_skill.try_add(trajectory.slice(t1, t2), traj_idx, t1, t2): 235 | if best_skill is None or \ 236 | new_skill.state_similarity * self.state_weight + new_skill.action_similarity * self.action_weight > \ 237 | best_skill.state_similarity * self.state_weight + best_skill.action_similarity * self.action_weight: 238 | best_skill = new_skill 239 | return best_skill 240 | 241 | def beam_search(self, skills: List[Skill], return_sampled: bool = False) -> List[Skill]: 242 | beams = [dict(skillset=[], unused=skills.copy()) for _ in range(self.num_beams)] 243 | sampled_skills = dict() 244 | seen = set() 245 | while any(len(b["unused"]) > 0 for b in beams): 246 | new_beams = [dict(skillset=b["skillset"], unused=[]) for b in beams] 247 | for beam in beams: 248 | for _ in range(self.num_beams): 249 | unused = beam["unused"].copy() 250 | idx = 0 251 | while idx < len(unused): 252 | if not all(x.is_compatible(unused[idx]) for x in beam["skillset"]): 253 | unused.pop(idx) 254 | elif tuple(beam["skillset"] + [unused[idx]]) in seen: 255 | idx += 1 256 | else: 257 | new_skill = unused.pop(idx) 258 | if new_skill not in sampled_skills: 259 | sampled_skills[new_skill] = 0 260 | sampled_skills[new_skill] += 1 261 | new_beams.append(dict(skillset=beam["skillset"] + [new_skill], unused=unused)) 262 | seen.add(tuple(new_beams[-1]["skillset"])) 263 | break 264 | beams = sorted( 265 | new_beams, 266 | key=lambda x: sum(self.score_skill(skill) for skill in x["skillset"]), 267 | reverse=True 268 | )[:self.num_beams] 269 | 270 | if return_sampled: 271 | return beams[0]["skillset"], sampled_skills 272 | else: 273 | return beams[0]["skillset"] -------------------------------------------------------------------------------- /static/js/bulma-slider.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define([], factory); 6 | else if(typeof exports === 'object') 7 | exports["bulmaSlider"] = factory(); 8 | else 9 | root["bulmaSlider"] = factory(); 10 | })(typeof self !== 'undefined' ? 
self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = ""; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 0); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 79 | 80 | "use strict"; 81 | Object.defineProperty(__webpack_exports__, "__esModule", { value: true }); 82 | /* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; }); 83 | /* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1); 84 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 85 | 86 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 87 | 88 | var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; }; 89 | 90 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 91 | 92 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 93 | 94 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 95 | 96 | 97 | 98 | var isString = function isString(unknown) { 99 | return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 
'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]'; 100 | }; 101 | 102 | var bulmaSlider = function (_EventEmitter) { 103 | _inherits(bulmaSlider, _EventEmitter); 104 | 105 | function bulmaSlider(selector) { 106 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: 
'_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? 
arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /static/js/bulma-carousel.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaCarousel=e():t.bulmaCarousel=e()}("undefined"!=typeof self?self:this,function(){return function(i){var n={};function s(t){if(n[t])return n[t].exports;var e=n[t]={i:t,l:!1,exports:{}};return i[t].call(e.exports,e,e.exports,s),e.l=!0,e.exports}return s.m=i,s.c=n,s.d=function(t,e,i){s.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:i})},s.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return s.d(e,"a",e),e},s.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},s.p="",s(s.s=5)}([function(t,e,i){"use strict";i.d(e,"d",function(){return s}),i.d(e,"e",function(){return r}),i.d(e,"b",function(){return o}),i.d(e,"c",function(){return a}),i.d(e,"a",function(){return l});var n=i(2),s=function(e,t){(t=Array.isArray(t)?t:t.split(" ")).forEach(function(t){e.classList.remove(t)})},r=function(t){return t.getBoundingClientRect().width||t.offsetWidth},o=function(t){return t.getBoundingClientRect().height||t.offsetHeight},a=function(t){var e=1=t._x&&this._x<=e._x&&this._y>=t._y&&this._y<=e._y}},{key:"constrain",value:function(t,e){if(t._x>e._x||t._y>e._y)return this;var i=this._x,n=this._y;return null!==t._x&&(i=Math.max(i,t._x)),null!==e._x&&(i=Math.min(i,e._x)),null!==t._y&&(n=Math.max(n,t._y)),null!==e._y&&(n=Math.min(n,e._y)),new s(i,n)}},{key:"reposition",value:function(t){t.style.top=this._y+"px",t.style.left=this._x+"px"}},{key:"toString",value:function(){return"("+this._x+","+this._y+")"}},{key:"x",get:function(){return this._x},set:function(){var 
t=0this.state.length-this.slidesToShow&&!this.options.centerMode?this.state.next=this.state.index:this.state.next=this.state.index+this.slidesToScroll,this.show()}},{key:"previous",value:function(){this.options.loop||this.options.infinite||0!==this.state.index?this.state.next=this.state.index-this.slidesToScroll:this.state.next=this.state.index,this.show()}},{key:"start",value:function(){this._autoplay.start()}},{key:"pause",value:function(){this._autoplay.pause()}},{key:"stop",value:function(){this._autoplay.stop()}},{key:"show",value:function(t){var e=1this.options.slidesToShow&&(this.options.slidesToScroll=this.slidesToShow),this._breakpoint.init(),this.state.index>=this.state.length&&0!==this.state.index&&(this.state.index=this.state.index-this.slidesToScroll),this.state.length<=this.slidesToShow&&(this.state.index=0),this._ui.wrapper.appendChild(this._navigation.init().render()),this._ui.wrapper.appendChild(this._pagination.init().render()),this.options.navigationSwipe?this._swipe.bindEvents():this._swipe._bindEvents(),this._breakpoint.apply(),this._slides.forEach(function(t){return e._ui.container.appendChild(t)}),this._transitioner.init().apply(!0,this._setHeight.bind(this)),this.options.autoplay&&this._autoplay.init().start()}},{key:"destroy",value:function(){var e=this;this._unbindEvents(),this._items.forEach(function(t){e.element.appendChild(t)}),this.node.remove()}},{key:"id",get:function(){return this._id}},{key:"index",set:function(t){this._index=t},get:function(){return this._index}},{key:"length",set:function(t){this._length=t},get:function(){return this._length}},{key:"slides",get:function(){return this._slides},set:function(t){this._slides=t}},{key:"slidesToScroll",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToScroll():1}},{key:"slidesToShow",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToShow():1}},{key:"direction",get:function(){return"rtl"===this.element.dir.toLowerCase()||"rtl"===this.element.style.direction?"rtl":"ltr"}},{key:"wrapper",get:function(){return this._ui.wrapper}},{key:"wrapperWidth",get:function(){return this._wrapperWidth||0}},{key:"container",get:function(){return this._ui.container}},{key:"containerWidth",get:function(){return this._containerWidth||0}},{key:"slideWidth",get:function(){return this._slideWidth||0}},{key:"transitioner",get:function(){return this._transitioner}}],[{key:"attach",value:function(){var i=this,t=0>t/4).toString(16)})}},function(t,e,i){"use strict";var n=i(3),s=i(8),r=function(){function n(t,e){for(var i=0;i=t.slider.state.length-t.slider.slidesToShow&&!t.slider.options.loop&&!t.slider.options.infinite?t.stop():t.slider.next())},this.slider.options.autoplaySpeed))}},{key:"stop",value:function(){this._interval=clearInterval(this._interval),this.emit("stop",this)}},{key:"pause",value:function(){var t=this,e=0parseInt(e.changePoint,10)}),this._currentBreakpoint=this._getActiveBreakpoint(),this}},{key:"destroy",value:function(){this._unbindEvents()}},{key:"_bindEvents",value:function(){window.addEventListener("resize",this[s]),window.addEventListener("orientationchange",this[s])}},{key:"_unbindEvents",value:function(){window.removeEventListener("resize",this[s]),window.removeEventListener("orientationchange",this[s])}},{key:"_getActiveBreakpoint",value:function(){var t=!0,e=!1,i=void 0;try{for(var n,s=this.options.breakpoints[Symbol.iterator]();!(t=(n=s.next()).done);t=!0){var r=n.value;if(r.changePoint>=window.innerWidth)return 
r}}catch(t){e=!0,i=t}finally{try{!t&&s.return&&s.return()}finally{if(e)throw i}}return this._defaultBreakpoint}},{key:"getSlidesToShow",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToShow:this._defaultBreakpoint.slidesToShow}},{key:"getSlidesToScroll",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToScroll:this._defaultBreakpoint.slidesToScroll}},{key:"apply",value:function(){this.slider.state.index>=this.slider.state.length&&0!==this.slider.state.index&&(this.slider.state.index=this.slider.state.index-this._currentBreakpoint.slidesToScroll),this.slider.state.length<=this._currentBreakpoint.slidesToShow&&(this.slider.state.index=0),this.options.loop&&this.slider._loop.init().apply(),this.options.infinite&&this.slider._infinite.init().apply(),this.slider._setDimensions(),this.slider._transitioner.init().apply(!0,this.slider._setHeight.bind(this.slider)),this.slider._setClasses(),this.slider._navigation.refresh(),this.slider._pagination.refresh()}},{key:s,value:function(t){var e=this._getActiveBreakpoint();e.slidesToShow!==this._currentBreakpoint.slidesToShow&&(this._currentBreakpoint=e,this.apply())}}]),e}();e.a=r},function(t,e,i){"use strict";var n=function(){function n(t,e){for(var i=0;ithis.slider.state.length-1-this._infiniteCount;i-=1)e=i-1,t.unshift(this._cloneSlide(this.slider.slides[e],e-this.slider.state.length));for(var n=[],s=0;s=this.slider.state.length?(this.slider.state.index=this.slider.state.next=this.slider.state.next-this.slider.state.length,this.slider.transitioner.apply(!0)):this.slider.state.next<0&&(this.slider.state.index=this.slider.state.next=this.slider.state.length+this.slider.state.next,this.slider.transitioner.apply(!0)))}},{key:"_cloneSlide",value:function(t,e){var i=t.cloneNode(!0);return i.dataset.sliderIndex=e,i.dataset.cloned=!0,(i.querySelectorAll("[id]")||[]).forEach(function(t){t.setAttribute("id","")}),i}}]),e}();e.a=s},function(t,e,i){"use strict";var n=i(12),s=function(){function n(t,e){for(var i=0;ithis.slider.state.length-this.slider.slidesToShow&&Object(n.a)(this.slider._slides[this.slider.state.length-1],this.slider.wrapper)?this.slider.state.next=0:this.slider.state.next=Math.min(Math.max(this.slider.state.next,0),this.slider.state.length-this.slider.slidesToShow):this.slider.state.next=0:this.slider.state.next<=0-this.slider.slidesToScroll?this.slider.state.next=this.slider.state.length-this.slider.slidesToShow:this.slider.state.next=0)}}]),e}();e.a=r},function(t,e,i){"use strict";i.d(e,"a",function(){return n});var n=function(t,e){var i=t.getBoundingClientRect();return e=e||document.documentElement,0<=i.top&&0<=i.left&&i.bottom<=(window.innerHeight||e.clientHeight)&&i.right<=(window.innerWidth||e.clientWidth)}},function(t,e,i){"use strict";var n=i(14),s=i(1),r=function(){function n(t,e){for(var 
i=0;ithis.slider.slidesToShow?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.remove("is-hidden"),0===this.slider.state.next?(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.remove("is-hidden")):this.slider.state.next>=this.slider.state.length-this.slider.slidesToShow&&!this.slider.options.centerMode?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden")):this.slider.state.next>=this.slider.state.length-1&&this.slider.options.centerMode&&(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden"))):(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.add("is-hidden")))}},{key:"render",value:function(){return this.node}}]),e}();e.a=o},function(t,e,i){"use strict";e.a=function(t){return'
'+t.previous+'
\n
'+t.next+"
"}},function(t,e,i){"use strict";var n=i(16),s=i(17),r=i(1),o=function(){function n(t,e){for(var i=0;ithis.slider.slidesToShow){for(var t=0;t<=this._count;t++){var e=document.createRange().createContextualFragment(Object(s.a)()).firstChild;e.dataset.index=t*this.slider.slidesToScroll,this._pages.push(e),this._ui.container.appendChild(e)}this._bindEvents()}}},{key:"onPageClick",value:function(t){this._supportsPassive||t.preventDefault(),this.slider.state.next=t.currentTarget.dataset.index,this.slider.show()}},{key:"onResize",value:function(){this._draw()}},{key:"refresh",value:function(){var e=this,t=void 0;(t=this.slider.options.infinite?Math.ceil(this.slider.state.length-1/this.slider.slidesToScroll):Math.ceil((this.slider.state.length-this.slider.slidesToShow)/this.slider.slidesToScroll))!==this._count&&(this._count=t,this._draw()),this._pages.forEach(function(t){t.classList.remove("is-active"),parseInt(t.dataset.index,10)===e.slider.state.next%e.slider.state.length&&t.classList.add("is-active")})}},{key:"render",value:function(){return this.node}}]),e}();e.a=a},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";var n=i(4),s=i(1),r=function(){function n(t,e){for(var i=0;iMath.abs(this._lastTranslate.y)&&(this._supportsPassive||t.preventDefault(),t.stopPropagation())}}},{key:"onStopDrag",value:function(t){this._origin&&this._lastTranslate&&(Math.abs(this._lastTranslate.x)>.2*this.width?this._lastTranslate.x<0?this.slider.next():this.slider.previous():this.slider.show(!0)),this._origin=null,this._lastTranslate=null}}]),e}();e.a=o},function(t,e,i){"use strict";var n=i(20),s=i(21),r=function(){function n(t,e){for(var i=0;it.x?(s.x=0,this.slider.state.next=0):this.options.vertical&&Math.abs(this._position.y)>t.y&&(s.y=0,this.slider.state.next=0)),this._position.x=s.x,this._position.y=s.y,this.options.centerMode&&(this._position.x=this._position.x+this.slider.wrapperWidth/2-Object(o.e)(i)/2),"rtl"===this.slider.direction&&(this._position.x=-this._position.x,this._position.y=-this._position.y),this.slider.container.style.transform="translate3d("+this._position.x+"px, "+this._position.y+"px, 0)",n.x>t.x&&this.slider.transitioner.end()}}},{key:"onTransitionEnd",value:function(t){"translate"===this.options.effect&&(this.transitioner.isAnimating()&&t.target==this.slider.container&&this.options.infinite&&this.slider._infinite.onTransitionEnd(t),this.transitioner.end())}}]),n}();e.a=n},function(t,e,i){"use strict";e.a={initialSlide:0,slidesToScroll:1,slidesToShow:1,navigation:!0,navigationKeys:!0,navigationSwipe:!0,pagination:!0,loop:!1,infinite:!1,effect:"translate",duration:300,timing:"ease",autoplay:!1,autoplaySpeed:3e3,pauseOnHover:!0,breakpoints:[{changePoint:480,slidesToShow:1,slidesToScroll:1},{changePoint:640,slidesToShow:2,slidesToScroll:2},{changePoint:768,slidesToShow:3,slidesToScroll:3}],onReady:null,icons:{previous:'\n \n ',next:'\n \n '}}},function(t,e,i){"use strict";e.a=function(t){return'
\n
\n
'}},function(t,e,i){"use strict";e.a=function(){return'
'}}]).default}); --------------------------------------------------------------------------------