├── sso
│   ├── __init__.py
│   ├── env
│   │   ├── nethack
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── base.py
│   │   │   └── maps.py
│   │   ├── __init__.py
│   │   └── scienceworld.py
│   ├── memory
│   │   ├── examples.py
│   │   ├── __init__.py
│   │   └── skillset.py
│   ├── llm
│   │   ├── __init__.py
│   │   └── gpt.py
│   ├── agent
│   │   ├── __init__.py
│   │   ├── fewshot.py
│   │   ├── reflexion.py
│   │   └── skills.py
│   ├── utils.py
│   ├── trajectory.py
│   └── skill.py
├── .gitignore
├── static
│   ├── images
│   │   ├── sso.png
│   │   ├── uci.png
│   │   ├── usage.png
│   │   ├── incontext.png
│   │   ├── scienceworld.png
│   │   ├── sso_example.png
│   │   ├── ai2-logo-header.png
│   │   ├── ai2_website_top.png
│   │   └── aristo-logo-header.png
│   ├── js
│   │   ├── index.js
│   │   ├── bulma-slider.min.js
│   │   ├── bulma-slider.js
│   │   └── bulma-carousel.min.js
│   └── css
│       ├── index.css
│       ├── bulma-carousel.min.css
│       └── bulma-slider.min.css
├── requirements.txt
├── README.md
├── main.py
├── LICENSE
└── index.html
/sso/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sso/env/nethack/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv/
2 | results/
3 | .vscode/
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/static/images/sso.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/sso.png
--------------------------------------------------------------------------------
/static/images/uci.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/uci.png
--------------------------------------------------------------------------------
/static/images/usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/usage.png
--------------------------------------------------------------------------------
/static/images/incontext.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/incontext.png
--------------------------------------------------------------------------------
/static/images/scienceworld.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/scienceworld.png
--------------------------------------------------------------------------------
/static/images/sso_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/sso_example.png
--------------------------------------------------------------------------------
/static/images/ai2-logo-header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/ai2-logo-header.png
--------------------------------------------------------------------------------
/static/images/ai2_website_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/ai2_website_top.png
--------------------------------------------------------------------------------
/static/images/aristo-logo-header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/sso/HEAD/static/images/aristo-logo-header.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai==0.28.0
2 | git+https://github.com/allenai/ScienceWorld.git@exhaustivevalidactions
3 | matplotlib==3.5.3
4 | networkx==2.6.3
5 | minihack==0.1.4
6 | nle-language-wrapper==0.2.0
--------------------------------------------------------------------------------
/sso/memory/examples.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from sso.trajectory import Trajectory
4 | from sso.memory import Memory
5 |
6 |
7 | class ExamplesMemory(Memory):
8 |
9 | def __init__(self, **kwargs):
10 | super().__init__(**kwargs)
11 | self.best_reward = float("-inf")
12 |
13 | def insert(self, trajectory: Trajectory) -> None:
14 | reward = sum(x.reward for x in trajectory if x.reward is not None)
15 | if reward >= self.best_reward:
16 | self.trajectories.append(trajectory)
17 | self.best_reward = max(reward, self.best_reward)
18 | best_trajectories = []
19 | for traj in self.trajectories:
20 | if sum(x.reward for x in traj if x.reward is not None) == self.best_reward:
21 | best_trajectories.append(traj)
22 | self.trajectories = best_trajectories
23 |
24 | def get_memories(self, trajectory: Trajectory = None, n: int = None) -> List[Trajectory]:
25 | if n is None:
26 | n = len(self.trajectories)
27 | return self.trajectories[-n:]
28 |
--------------------------------------------------------------------------------
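`ExamplesMemory.insert` keeps only trajectories that tie the best total reward seen so far. A minimal, self-contained sketch of that retention rule, with plain reward lists standing in for the repo's `Trajectory` objects:

```python
from typing import List

best_reward = float("-inf")
kept: List[List[float]] = []  # stand-ins for stored Trajectory objects

def insert(rewards: List[float]) -> None:
    """Keep only trajectories whose total reward ties the best seen so far."""
    global best_reward, kept
    total = sum(r for r in rewards if r is not None)
    if total >= best_reward:
        kept.append(rewards)
        best_reward = max(total, best_reward)
        kept = [t for t in kept if sum(r for r in t if r is not None) == best_reward]

insert([0.0, 1.0])   # kept; best_reward becomes 1.0
insert([0.5, 1.0])   # kept; best_reward becomes 1.5 and the first trajectory is dropped
insert([0.0, 0.5])   # ignored; 0.5 < 1.5
print(len(kept), best_reward)  # 1 1.5
```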
/sso/llm/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | from sso.llm.gpt import get_response as get_response_gpt, get_embedding as get_embedding_gpt
4 |
5 |
6 | global DEFAULT_MODEL
7 | DEFAULT_MODEL = None
8 |
9 | global DEFAULT_TEMP
10 | DEFAULT_TEMP = None
11 |
12 | global DEFAULT_EMBEDDING
13 | DEFAULT_EMBEDDING = None
14 |
15 |
16 | def set_default_model(model: str = None, temp: float = None, embedding: str = None) -> None:
17 | if model is not None:
18 | global DEFAULT_MODEL
19 | DEFAULT_MODEL = model
20 |
21 | if temp is not None:
22 | global DEFAULT_TEMP
23 | DEFAULT_TEMP = temp
24 |
25 | if embedding is not None:
26 | global DEFAULT_EMBEDDING
27 | DEFAULT_EMBEDDING = embedding
28 |
29 |
30 | def query_llm(messages: List[Dict[str, str]], model: str = None, temperature: float = 1, **generation_kwargs) -> str:
31 | if model is None:
32 | global DEFAULT_MODEL
33 | model = DEFAULT_MODEL
34 | if temperature is None:
35 | global DEFAULT_TEMP
36 | temperature = DEFAULT_TEMP
37 |     if model.startswith("gpt"):
38 |         return get_response_gpt(model, messages, temperature=temperature, **generation_kwargs)
39 |     raise ValueError("Unsupported chat model: {}".format(model))
40 |
41 | def get_embedding(content: str, model: str = None) -> List[float]:
42 | if model is None:
43 | global DEFAULT_EMBEDDING
44 | model = DEFAULT_EMBEDDING
45 |
46 |     if model.startswith("text"):
47 |         return get_embedding_gpt(model, content)
48 |     raise ValueError("Unsupported embedding model: {}".format(model))
--------------------------------------------------------------------------------
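A minimal usage sketch of the helpers above, assuming the package is importable and `OPENAI_API_KEY` is set; the model names are only illustrative:

```python
from sso.llm import set_default_model, query_llm, get_embedding

# Register defaults once; query_llm and get_embedding fall back to them
# when no model is passed explicitly.
set_default_model(model="gpt-4", temp=0.7, embedding="text-embedding-ada-002")

reply = query_llm([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Name one NetHack command."},
])
vector = get_embedding("open the door")
print(reply)
print(len(vector))
```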
/sso/env/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Any
2 | from abc import ABC, abstractmethod, abstractproperty
3 |
4 | from sso.trajectory import State
5 |
6 |
7 | class Env(ABC):
8 |
9 | @abstractproperty
10 | def max_steps(self) -> int:
11 | """
12 | :return: maximum number of steps in the simulation
13 | """
14 | raise NotImplementedError
15 |
16 | @abstractproperty
17 | def num_train(self) -> int:
18 | """
19 | :return: number of training variants
20 | """
21 | raise NotImplementedError
22 |
23 | @abstractproperty
24 | def num_test(self) -> int:
25 | """
26 | :return: number of test variants
27 | """
28 | raise NotImplementedError
29 |
30 | @abstractproperty
31 | def train_ids(self) -> Tuple[str]:
32 | """
33 | :return: training task IDs
34 | """
35 | raise NotImplementedError
36 |
37 | @abstractproperty
38 | def test_ids(self) -> Tuple[str]:
39 | """
40 | :return: test task IDs
41 | """
42 | raise NotImplementedError
43 |
44 | @abstractmethod
45 | def reset(self, task_id: str = None, test: bool = False, **kwargs) -> Tuple[State, Dict[str, Any]]:
46 | """
47 | Reset the simulation.
48 |
49 | :param test: whether to use a test task
50 | :param kwargs: additional arguments
51 | :return: initial state and info
52 | """
53 | raise NotImplementedError
54 |
55 | @abstractmethod
56 | def step(self, action: str) -> Tuple[State, float, bool, Dict[str, Any]]:
57 | """
58 | Perform an action in the simulation.
59 |
60 | :param action: action to perform
61 | :return: next state, reward, done, info
62 | """
63 | raise NotImplementedError
64 |
--------------------------------------------------------------------------------
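As a sketch of how this interface might be driven, the helper below runs a fixed action sequence through any concrete `Env`; the action list stands in for an actual policy and is purely illustrative:

```python
from typing import List
from sso.env import Env

def rollout(env: Env, actions: List[str], task_id: str = None) -> float:
    """Run a fixed action sequence and return the total reward collected."""
    state, info = env.reset(task_id=task_id, test=False)
    total_reward = 0.0
    for action in actions:
        state, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
```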
/sso/llm/gpt.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | import os
3 | import time
4 | from functools import lru_cache
5 | import openai
6 | openai.api_key = os.environ["OPENAI_API_KEY"]
7 |
8 |
9 | def get_response(model: str, messages: List[Dict[str, str]], max_tries=50, temperature=1, **kwargs) -> str:
10 | completion = None
11 | num_tries = 0
12 | while not completion and num_tries < max_tries:
13 | try:
14 | completion = openai.ChatCompletion.create(
15 | model=model,
16 |                 messages=messages, temperature=temperature,
17 |                 **kwargs
18 | ).choices[0].message.content
19 | break
20 | except Exception as e:
21 | num_tries += 1
22 | print("try {}: {}".format(num_tries, e))
23 | if "maximum context length" in str(e):
24 | if len(messages) > 3:
25 | if messages[0]["role"] == "system":
26 | messages = [messages[0]] + messages[3:]
27 | else:
28 | messages = messages[2:]
29 | else:
30 | raise RuntimeError("messages too long")
31 | time.sleep(2)
32 | if not completion:
33 | raise RuntimeError("Failed to get response from API")
34 | return completion
35 |
36 |
37 | @lru_cache(maxsize=1000)
38 | def get_embedding(model: str, content: str, max_tries=50) -> List[float]:
39 | embedding = None
40 | num_tries = 0
41 | while not embedding and num_tries < max_tries:
42 | try:
43 | embedding = openai.Embedding.create(model=model, input=content).data[0].embedding
44 | break
45 | except Exception as e:
46 | num_tries += 1
47 | print("try {}: {}".format(num_tries, e))
48 | time.sleep(2)
49 | if not embedding:
50 | raise RuntimeError("Failed to get embedding response from API")
51 | return embedding
52 |
--------------------------------------------------------------------------------
/sso/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Dict
2 |
3 | from sso.trajectory import Trajectory
4 | from sso.memory import Memory
5 |
6 |
7 | class Agent:
8 |
9 | def __init__(self, memory: Memory = None, max_history: int = 10, trim_states: bool = True):
10 | assert memory is not None
11 | self.memory = memory
12 | self._log = []
13 | self.freeze_memory = False
14 | self.max_history = max_history
15 | self.trim_states = trim_states
16 |
17 | def train(self, train: bool = True) -> None:
18 | self.freeze_memory = not train
19 |
20 | def clear(self) -> None:
21 | self.memory.clear()
22 |
23 | def act(self, trajectory: Trajectory) -> str:
24 | self.step_log()
25 | self.log("reward", trajectory[-1].reward)
26 | self.log("last_action", trajectory[-1].last_action)
27 | self.log("state_description", trajectory[-1].state_description)
28 |
29 | # Get next action
30 | return None
31 |
32 | def _update_memory(self, trajectory: Trajectory) -> None:
33 | raise NotImplementedError
34 |
35 | def record_done(self, trajectory: Trajectory) -> None:
36 | self.step_log()
37 | self.log("reward", trajectory[-1].reward)
38 | self.log("last_action", trajectory[-1].last_action)
39 | self.log("state_description", trajectory[-1].state_description)
40 | if not self.freeze_memory:
41 | self._update_memory(trajectory)
42 |
43 | def save(self, save_dir: str) -> None:
44 | self.memory.save(save_dir)
45 |
46 | def load(self, load_dir: str, rebuild: bool = False) -> None:
47 | self.memory.load(load_dir, rebuild=rebuild)
48 |
49 | def log(self, key: str, value: Any) -> None:
50 | if len(self._log) == 0:
51 | self.step_log()
52 | if key in self._log[-1]:
53 | if isinstance(self._log[-1][key], list):
54 | self._log[-1][key].append(value)
55 | else:
56 | self._log[-1][key] = [self._log[-1][key], value]
57 | else:
58 | self._log[-1][key] = value
59 |
60 | def step_log(self) -> None:
61 | self._log.append(dict())
62 |
63 | def reset_log(self) -> None:
64 | self._log = []
65 |
66 | def get_log(self) -> List[Dict[str, Any]]:
67 | return self._log
68 |
--------------------------------------------------------------------------------
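One non-obvious detail above is `Agent.log`: logging the same key twice within a step converts the value into a list. A small standalone illustration of that merge rule:

```python
step = {}  # stand-in for the dict appended by Agent.step_log

def log(key, value):
    if key in step:
        if isinstance(step[key], list):
            step[key].append(value)
        else:
            step[key] = [step[key], value]
    else:
        step[key] = value

log("action_generation", "go north")
log("action_generation", "go east")
log("reward", 1.0)
print(step)  # {'action_generation': ['go north', 'go east'], 'reward': 1.0}
```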
/sso/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List, TYPE_CHECKING
3 | if TYPE_CHECKING:
4 | from sso.trajectory import State
5 |
6 | from functools import lru_cache
7 | import numpy as np
8 |
9 | from sso.llm import get_embedding
10 |
11 |
12 | @lru_cache(maxsize=100000)
13 | def clean_feature(feature: str, remove_fill_words=False) -> str:
14 | words_to_skip = set(["to", "the", "for", "on", "in", "a", "an", ""]) if remove_fill_words else set([""])
15 | words = []
16 | for word in feature.lower().split():
17 | word = word.strip(".,!?\"')(}{][:; \t\n")
18 | if word not in words_to_skip:
19 | words.append(word)
20 | return " ".join(words)
21 |
22 |
23 | @lru_cache(maxsize=100000)
24 | def _get_emedding_similarity(text1: str, text2: str) -> float:
25 | if clean_feature(text1, remove_fill_words=True) == clean_feature(text2, remove_fill_words=True):
26 | return 1.0
27 | embedding1 = np.array(get_embedding(text1))
28 | embedding2 = np.array(get_embedding(text2))
29 | return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
30 |
31 |
32 | def get_similarity(text1: str, text2: str) -> float:
33 | return _get_emedding_similarity(text1, text2)
34 |
35 |
36 | def get_feature_similarity(feature: str, features: List[str]) -> float:
37 | if isinstance(features, str):
38 | features = [features]
39 | return max(get_similarity(feature, feature2) for feature2 in features)
40 |
41 |
42 | def get_similar_action(action: str, actions: List[str]) -> str:
43 | for x in actions:
44 | if clean_feature(action, remove_fill_words=True) == clean_feature(x, remove_fill_words=True):
45 | return x
46 | return None
47 |
48 |
49 | def get_state_similarity(state1: State, state2: State, init_state: bool = False) -> float:
50 | if state1.last_action is None and state2.last_action is None or init_state:
51 | return _get_emedding_similarity(state1.state_description, state2.state_description)
52 | elif state1.last_action is not None and state2.last_action is not None:
53 | return _get_emedding_similarity(
54 | "You chose to {}.\n\n".format(state1.last_action) + state1.state_description,
55 | "You chose to {}.\n\n".format(state2.last_action) + state2.state_description
56 | )
57 | else:
58 | return 0
59 |
--------------------------------------------------------------------------------
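The embedding similarity above is plain cosine similarity between OpenAI embeddings. A self-contained sketch of the same computation, with fixed vectors standing in for API embeddings so it runs offline:

```python
import numpy as np

def cosine_similarity(a, b) -> float:
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-in vectors; in sso.utils these come from sso.llm.get_embedding.
print(cosine_similarity([1.0, 0.0, 1.0], [0.5, 0.5, 1.0]))  # ~0.866
```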
/sso/env/nethack/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union, List
2 | from nle_language_wrapper.nle_language_obsv import NLELanguageObsv
3 |
4 |
5 | NLE_LANG = NLELanguageObsv()
6 |
7 |
8 | ACTION_TEMPLATES = [
9 | "k: move north",
10 | "l: move east",
11 | "j: move south",
12 | "h: move west",
13 | "y: move northwest",
14 | "u: move northeast",
15 | "b: move southwest",
16 | "n: move southeast",
17 | ",: pick up at current location",
18 | "d: drop item",
19 | "o: open door",
20 | "t: throw item",
21 | "e: eat food",
22 | "w: weild weapon",
23 | "W: wear armor",
24 | "T: take off armor",
25 | "P: put on accessory",
26 | "R: remove accessory",
27 | "q: quaff/drink potion",
28 | "a: apply/use item",
29 | "z: zap wand"
30 | ]
31 |
32 |
33 | def get_message(obs) -> str:
34 | message = NLE_LANG.text_message(obs["tty_chars"]).decode("latin-1")
35 | if message is None:
36 | message = ""
37 | return message.replace("\n", "; ")
38 |
39 |
40 | def get_vision(obs) -> str:
41 | vision = NLE_LANG.text_glyphs(obs["glyphs"], obs["blstats"]).decode("latin-1")
42 | for dir1 in ["east", "west", "north", "south"]:
43 | for dir2 in ["northwest", "northeast", "southwest", "southeast"]:
44 | vision = vision.replace(dir1 + dir2, dir2)
45 | return vision
46 |
47 |
48 | def get_lang_obs(obs: Dict, as_list: bool = False, use_stats: bool = False) -> Union[str, List[str]]:
49 | text_fields = {
50 | "text_glyphs": get_vision(obs),
51 | "text_message": get_message(obs),
52 | "text_blstats": NLE_LANG.text_blstats(obs["blstats"]).decode("latin-1"),
53 | "text_inventory": NLE_LANG.text_inventory(obs["inv_strs"], obs["inv_letters"]).decode("latin-1"),
54 | "text_cursor": NLE_LANG.text_cursor(obs["glyphs"], obs["blstats"], obs["tty_cursor"]).decode("latin-1"),
55 | }
56 |
57 | lang_obs = ["You have " + x[3:] for x in text_fields["text_inventory"].split("\n") if x]
58 | if use_stats:
59 | lang_obs += [x for x in text_fields["text_blstats"].split("\n") if x]
60 | lang_obs += [text_fields["text_cursor"]]
61 | lang_obs += ["You see a " + x for x in text_fields["text_glyphs"].split("\n") if x]
62 | if text_fields["text_message"]:
63 | lang_obs += [text_fields["text_message"]]
64 | if as_list:
65 | return lang_obs
66 | else:
67 | return ". ".join(lang_obs)
68 |
--------------------------------------------------------------------------------
/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 | var INTERP_BASE = "./static/interpolation/stacked";
4 | var NUM_INTERP_FRAMES = 240;
5 |
6 | var interp_images = [];
7 | function preloadInterpolationImages() {
8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) {
9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg';
10 | interp_images[i] = new Image();
11 | interp_images[i].src = path;
12 | }
13 | }
14 |
15 | function setInterpolationImage(i) {
16 | var image = interp_images[i];
17 | image.ondragstart = function() { return false; };
18 | image.oncontextmenu = function() { return false; };
19 | $('#interpolation-image-wrapper').empty().append(image);
20 | }
21 |
22 |
23 | $(document).ready(function() {
24 | // Check for click events on the navbar burger icon
25 | $(".navbar-burger").click(function() {
26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
27 | $(".navbar-burger").toggleClass("is-active");
28 | $(".navbar-menu").toggleClass("is-active");
29 |
30 | });
31 |
32 | var options = {
33 | slidesToScroll: 1,
34 | slidesToShow: 3,
35 | loop: true,
36 | infinite: true,
37 | autoplay: false,
38 | autoplaySpeed: 3000,
39 | }
40 |
41 | // Initialize all div with carousel class
42 | var carousels = bulmaCarousel.attach('.carousel', options);
43 |
44 | // Loop on each carousel initialized
45 | for(var i = 0; i < carousels.length; i++) {
46 | // Add listener to event
47 | carousels[i].on('before:show', state => {
48 | console.log(state);
49 | });
50 | }
51 |
52 | // Access to bulmaCarousel instance of an element
53 | var element = document.querySelector('#my-element');
54 | if (element && element.bulmaCarousel) {
55 | // bulmaCarousel instance is available as element.bulmaCarousel
56 | element.bulmaCarousel.on('before-show', function(state) {
57 | console.log(state);
58 | });
59 | }
60 |
61 | /*var player = document.getElementById('interpolation-video');
62 | player.addEventListener('loadedmetadata', function() {
63 | $('#interpolation-slider').on('input', function(event) {
64 | console.log(this.value, player.duration);
65 | player.currentTime = player.duration / 100 * this.value;
66 | })
67 | }, false);*/
68 | preloadInterpolationImages();
69 |
70 | $('#interpolation-slider').on('input', function(event) {
71 | setInterpolationImage(this.value);
72 | });
73 | setInterpolationImage(0);
74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
75 |
76 | bulmaSlider.attach();
77 |
78 | })
79 |
--------------------------------------------------------------------------------
/sso/agent/fewshot.py:
--------------------------------------------------------------------------------
1 | from sso.trajectory import Trajectory
2 | from sso.llm import query_llm
3 | from sso.agent import Agent
4 |
5 |
6 | class FewshotAgent(Agent):
7 |
8 | def __init__(self, fewshot: int = 3, **kwargs):
9 | super().__init__(**kwargs)
10 | self.fewshot = fewshot
11 |
12 | def act(self, trajectory: Trajectory) -> str:
13 | super().act(trajectory)
14 |
15 | # Get next action
16 | return self._act(trajectory)
17 |
18 | def _update_memory(self, trajectory: Trajectory) -> None:
19 | self.memory.insert(trajectory)
20 |
21 | def _act(self, trajectory: Trajectory) -> str:
22 | sub_trajectory = trajectory.slice(-self.max_history, None)
23 |
24 | system_message = "You are playing a text-based game in which you must interact with your surroundings to complete a task.\n\n"
25 | system_message += sub_trajectory.task_description
26 | system_message += "\n\nGiven the state, reflect on what has happened so far, explain your plan to accomplish the task and then output the next action to execute (use one of the action templates below)."
27 | system_message += "\n\nFor example:\nThe last action had the effect of... To accomplish the task, I will need to...\nCurrent subgoal: [subgoal]\nNext action: [action]"
28 | if sub_trajectory[-1].action_prompt is not None:
29 | system_message += "\n\n" + sub_trajectory[-1].action_prompt
30 |
31 | if len(self.memory.trajectories) > 0:
32 | examples = self.memory.get_memories(n=self.fewshot)
33 | system_message += "\n\nUse the following example trajector{} to help you accomplish the task:".format(
34 | "ies" if len(examples) > 1 else "y"
35 | )
36 | for traj in examples:
37 | system_message += "\n\n" + traj.build_string(trim_state=self.trim_states)
38 |
39 | messages = [dict(role="system", content=system_message)]
40 |
41 | state_strings = sub_trajectory.build_string(use_task=False, use_actions=False, as_list=True, trim_state=self.trim_states)
42 | for i in range(len(sub_trajectory)):
43 | state = sub_trajectory[i]
44 |
45 | if i > 0:
46 | action_start = state.last_generation.lower().find("next action:")
47 | if action_start != -1:
48 | action_start += len("next action:")
49 | action_text = state.last_generation[:action_start+1] + " " + state.last_action
50 | else:
51 | action_text = "To accomplish the task, I will need to {}. Next action: {}".format(state.last_action, state.last_action)
52 | messages.append(dict(role="assistant", content=action_text))
53 |
54 | prompt = state_strings[i]
55 | messages.append(dict(role="user", content=prompt))
56 |
57 | response = query_llm(messages)
58 | self.log("action_messages", messages)
59 | self.log("action_generation", response)
60 | return response
61 |
--------------------------------------------------------------------------------
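The message list assembled by `_act` alternates user state descriptions with assistant turns that end in a `Next action:` line, all under a single system prompt. A hypothetical, abbreviated rendering of that structure:

```python
# Illustrative only: contents are abbreviated placeholders, not repo output.
messages = [
    {"role": "system", "content": "You are playing a text-based game... [task, output format, action templates, example trajectories]"},
    {"role": "user", "content": "[state 0 description]"},
    {"role": "assistant", "content": "The last action had the effect of... Next action: open door"},
    {"role": "user", "content": "[state 1 description]"},
]
# query_llm(messages) then returns the next reflection plus a "Next action: ..." string.
```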
/README.md:
--------------------------------------------------------------------------------
# Skill Set Optimization

Large language models (LLMs) have recently been used for sequential decision making in interactive environments. However, leveraging environment reward signals for continual LLM actor improvement is not straightforward. We propose Skill Set Optimization (SSO) for improving LLM actor performance through constructing and refining sets of transferable skills. SSO constructs skills by extracting common subtrajectories with high rewards and generating subgoals and instructions to represent each skill. These skills are provided to the LLM actor in-context to reinforce behaviors with high rewards. Then, SSO further refines the skill set by pruning skills that do not continue to result in high rewards. We evaluate our method in the classic videogame NetHack and the text environment ScienceWorld to demonstrate SSO's ability to optimize a set of skills and perform in-context policy improvement. SSO outperforms baselines by 40% in our custom NetHack task and outperforms the previous state-of-the-art in ScienceWorld by 35%.
Like other continual learning methods (e.g. Voyager, ExpeL, and CLIN agents), SSO uses in-context "memories" with information about the task and environment to improve the LLM actor's policy. The memories that SSO generates are instructions for achieving subgoals we call skills. Unlike previous work, SSO continuously evaluates generated memories, creates memories that define modular subgoals, and facilitates memory retrieval.
Each iteration of SSO includes skill construction, refinement, and in-context retrieval. To construct new skills, we extract potential subtrajectories; score them using discounted reward, similarity, and length; sample an updated skill set using beam search; and generate subgoals and instructions for each new skill. We refine the constructed skill set by filtering out skills that did not result in high rewards when used in previous trajectories. Then, when providing skills in-context, we retrieve only the most relevant skills based on cosine similarity between skill initial states and the current environment state.
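A rough sketch of the retrieval step, assuming each stored skill is a dict carrying an embedding of its initial state; the field name and function signature are illustrative, not the repo's exact API:

```python
import numpy as np

def cosine(a, b) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def retrieve_skills(current_state_embedding, skills, k=3):
    """Return the k skills whose initial states best match the current state."""
    return sorted(
        skills,
        key=lambda skill: cosine(current_state_embedding, skill["init_state_embedding"]),
        reverse=True,
    )[:k]
```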
Each row of this plot shows all of the skills created in the corresponding iteration and when they were executed. On both ScienceWorld and NetHack, SSO prunes most new skills after a few iterations. The LLM actor relies on more recently created skills as it continues to improve at the task, learning new skills and refining old ones.
SSO outperforms the previous state-of-the-art in ScienceWorld by 35% in task adaptation and 14% in task transfer. Learned and reinforced skills, such as those listed below, provide knowledge of subgoals that are transferable across tasks.
Excerpts from two such skills:

- "You move to the kitchen"
- "The stove is turned on. On the stove is: a substance called liquid [substance]"
Cite this work as:

```
@article{nottingham2024sso,
  author  = "Nottingham, Kolby and Majumder, Bodhisattwa Prasad and Dalvi Mishra, Bhavana and Singh, Sameer and Clark, Peter and Fox, Roy",
  title   = "Skill Set Optimization: Reinforcing Language Model Behavior via Transferable Skills",
  journal = "arXiv",
  year    = "2024",
  url     = "https://arxiv.org/abs/2402.03244"
}
```