├── tests
│   ├── __init__.py
│   ├── observations
│   │   ├── __init__.py
│   │   ├── serde
│   │   │   ├── __init__.py
│   │   │   ├── types
│   │   │   │   ├── __init__.py
│   │   │   │   ├── test_rewards.py
│   │   │   │   ├── test_base.py
│   │   │   │   ├── test_relics.py
│   │   │   │   ├── test_potions.py
│   │   │   │   └── test_cards.py
│   │   │   ├── test_card_reward.py
│   │   │   ├── test_shop.py
│   │   │   ├── test_combat_reward.py
│   │   │   ├── test_persistent.py
│   │   │   ├── test_combat.py
│   │   │   └── test_campfire.py
│   │   └── test_enemy_serialization.py
│   ├── test_card_reward_in_combat.py
│   ├── test_screenshotting.py
│   ├── env
│   │   ├── test_single_combat_env.py
│   │   └── test_rebooting.py
│   └── conftest.py
├── gym_sts
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── state_log_loader.py
│   │   └── state_logger.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── types.py
│   │   ├── single_combat.py
│   │   ├── utils.py
│   │   └── action_validation.py
│   ├── spaces
│   │   ├── __init__.py
│   │   ├── constants
│   │   │   ├── __init__.py
│   │   │   ├── campfire.py
│   │   │   ├── shop.py
│   │   │   ├── rewards.py
│   │   │   ├── base.py
│   │   │   ├── map.py
│   │   │   ├── events.py
│   │   │   ├── combat.py
│   │   │   └── potions.py
│   │   ├── observations
│   │   │   ├── __init__.py
│   │   │   ├── components
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── campfire.py
│   │   │   │   ├── event.py
│   │   │   │   ├── card_reward.py
│   │   │   │   ├── combat_reward.py
│   │   │   │   ├── shop.py
│   │   │   │   ├── persistent.py
│   │   │   │   └── combat.py
│   │   │   ├── types
│   │   │   │   ├── __init__.py
│   │   │   │   ├── campfire.py
│   │   │   │   ├── map.py
│   │   │   │   ├── potions.py
│   │   │   │   ├── relics.py
│   │   │   │   ├── cards.py
│   │   │   │   ├── rewards.py
│   │   │   │   └── base.py
│   │   │   ├── utils.py
│   │   │   ├── serializers.py
│   │   │   ├── spaces.py
│   │   │   └── observations.py
│   │   ├── actions.py
│   │   └── data.py
│   ├── build
│   │   ├── preferences
│   │   │   ├── STSDaily
│   │   │   ├── STSAchievements
│   │   │   ├── STSDataDefect
│   │   │   ├── STSDataVagabond
│   │   │   ├── STSDataWatcher
│   │   │   ├── STSSeenBosses
│   │   │   ├── STSDataTheSilent
│   │   │   ├── STSInputSettings
│   │   │   ├── STSBetaCardPreference
│   │   │   ├── STSInputSettings_Controller
│   │   │   ├── STSSound
│   │   │   ├── STSSaveSlots
│   │   │   ├── STSTips
│   │   │   ├── STSPlayer
│   │   │   ├── STSGameplaySettings
│   │   │   ├── STSUnlockProgress
│   │   │   ├── STSUnlocks
│   │   │   ├── STSSeenRelics
│   │   │   └── STSSeenCards
│   │   ├── asound.conf
│   │   ├── info.displayconfig
│   │   ├── communication_mod.config.properties
│   │   ├── superfastmode.config.properties
│   │   ├── pipe_to_host.sh
│   │   ├── pipe_locally.sh
│   │   └── Dockerfile
│   ├── communication
│   │   ├── __init__.py
│   │   ├── receiver.py
│   │   ├── sender.py
│   │   └── communicator.py
│   ├── env_repl.py
│   ├── exceptions.py
│   ├── constants.py
│   ├── rl
│   │   ├── utils.py
│   │   ├── action_masking.py
│   │   ├── metrics.py
│   │   └── mvp.py
│   ├── perf.py
│   ├── runner.py
│   └── test_valid_actions.py
├── setup.cfg
├── out
│   └── README.md
├── .gitignore
├── .pre-commit-config.yaml
├── pyproject.toml
└── README.md

/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/envs/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/spaces/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/tests/observations/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/spaces/constants/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/tests/observations/serde/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSDaily:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/tests/observations/serde/types/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSAchievements:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSDataDefect:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSDataVagabond:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSDataWatcher:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSSeenBosses:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/setup.cfg:
--------------------------------------------------------------------------------
[flake8]
max-line-length = 88
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSDataTheSilent:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSInputSettings:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSBetaCardPreference:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSInputSettings_Controller:
--------------------------------------------------------------------------------
{}
--------------------------------------------------------------------------------

/gym_sts/build/asound.conf:
--------------------------------------------------------------------------------
defaults.pcm.card 1
defaults.ctl.card 1
--------------------------------------------------------------------------------

/gym_sts/build/info.displayconfig:
--------------------------------------------------------------------------------
1024
576
1000000000
false
false
false
--------------------------------------------------------------------------------

/gym_sts/communication/__init__.py:
--------------------------------------------------------------------------------
from gym_sts.communication.communicator import Communicator  # noqa: F401
--------------------------------------------------------------------------------

/gym_sts/build/communication_mod.config.properties:
--------------------------------------------------------------------------------
command=/game/pipe_to_host.sh
runAtGameStart=true
verbose=true
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSSound:
--------------------------------------------------------------------------------
{
  "Ambience On": "true",
  "Mute in Bg": "true",
  "Master Volume": "0.0"
}
--------------------------------------------------------------------------------

/gym_sts/spaces/observations/__init__.py:
--------------------------------------------------------------------------------
from .observations import OBSERVATION_SPACE, Observation, ObservationError  # noqa: F401
--------------------------------------------------------------------------------

/gym_sts/spaces/constants/campfire.py:
--------------------------------------------------------------------------------
MAX_NUM_OPTIONS = 5  # Rest, smith, recall, and 2 of the 3 relic actions
LOG_NUM_OPTIONS = 3
--------------------------------------------------------------------------------

/out/README.md:
--------------------------------------------------------------------------------
# Output

This directory is exposed to the Docker container as a volume. Any files placed here appear at /out inside the container.
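
A minimal sketch of such a mount using the Docker SDK for Python (the `docker` dependency in `pyproject.toml`) follows. It assumes the container is started through `docker.from_env().containers.run(...)`, which this excerpt does not confirm, and the helper `out_volume_binding` is hypothetical: it only builds the `volumes` mapping that `containers.run` accepts.

```python
from pathlib import Path


def out_volume_binding(host_dir: str = "out") -> dict:
    """Build a docker-py ``volumes`` mapping that mounts ``host_dir`` at /out.

    Hypothetical helper for illustration; it is not part of gym_sts.
    """
    # Docker does not treat a relative source as a host directory (a bare name
    # such as "out" becomes a named volume), so resolve to an absolute path.
    host_path = str(Path(host_dir).resolve())
    return {host_path: {"bind": "/out", "mode": "rw"}}
```

With a Docker daemon running and the image already built, `docker.from_env().containers.run("sts", volumes=out_volume_binding("out"), detach=True)` would start the image tagged `sts` (`DOCKER_IMAGE_TAG` in `gym_sts/constants.py`) with the host's `out` directory visible at `/out`.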
--------------------------------------------------------------------------------

/gym_sts/build/superfastmode.config.properties:
--------------------------------------------------------------------------------
isDeltaMultiplied=true
deltaMultiplier=100.0
EXISTS=YES INDEED I EXIST
isInstantLerp=true
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSSaveSlots:
--------------------------------------------------------------------------------
{
  "DEFAULT_SLOT": "0",
  "PROFILE_NAME": "test",
  "COMPLETION": "34.011845",
  "PLAYTIME": "819"
}
--------------------------------------------------------------------------------

/gym_sts/spaces/constants/shop.py:
--------------------------------------------------------------------------------
import math


_SHOP_MAX_PRICE = 1000

SHOP_CARD_COUNT = 7
SHOP_RELIC_COUNT = 3
SHOP_POTION_COUNT = 3
SHOP_LOG_MAX_PRICE = math.ceil(math.log(_SHOP_MAX_PRICE, 2))
--------------------------------------------------------------------------------

/gym_sts/env_repl.py:
--------------------------------------------------------------------------------
# run with ipython -i gym_sts/env_repl.py

from gym_sts.envs.base import SlayTheSpireGymEnv


SlayTheSpireGymEnv.build_image()
env = SlayTheSpireGymEnv("lib", "mods", "out", headless=True)
obs = env.reset()
--------------------------------------------------------------------------------

/gym_sts/exceptions.py:
--------------------------------------------------------------------------------
class StSError(Exception):
    """
    General class of StS exceptions.
    """

    pass


class StSTimeoutError(TimeoutError, StSError):
    """
    Typically indicates a CommunicationMod response was not received in time.
    """

    pass
--------------------------------------------------------------------------------

/gym_sts/build/pipe_to_host.sh:
--------------------------------------------------------------------------------
#!/bin/bash
INPUT="out/stsai_input"
OUTPUT="out/stsai_output"

function cleanup() {
    kill $BG_PID
}

trap cleanup EXIT

cat $INPUT &
BG_PID=$!

# Sleep to allow time for the background process to start
sleep 0.2

cat > $OUTPUT
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSTips:
--------------------------------------------------------------------------------
{
  "NEOW_SKIP": "true",
  "COMBAT_TIP": "true",
  "SHUFFLE_TIP": "true",
  "POTION_TIP": "true",
  "INTENT_TIP": "true",
  "POWER_TIP": "true",
  "RELIC_TIP": "true",
  "ENERGY_USE_TIP": "true",
  "NEOW_INTRO": "true",
  "CARD_REWARD_TIP": "true"
}
--------------------------------------------------------------------------------

/gym_sts/constants.py:
--------------------------------------------------------------------------------
from pathlib import Path


PROJECT_ROOT = Path(__file__).parent.absolute()

JAVA_INSTALL = "/usr/bin/java"
MTS_JAR = "ModTheSpire.jar"
EXTRA_ARGS = [
    "--skip-launcher",
    "--skip-intro",
    "--mods",
    "basemod,CommunicationMod,superfastmode",
]

DOCKER_IMAGE_TAG = "sts"
--------------------------------------------------------------------------------

/gym_sts/build/pipe_locally.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Pipe CommunicationMod IO to FIFOs provided by the host

INPUT=$1
OUTPUT=$2

function cleanup() {
    kill $BG_PID
}

trap cleanup EXIT

cat $INPUT &
BG_PID=$!

# Sleep to allow time for the background process to start
sleep 0.2

cat > $OUTPUT
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSPlayer:
--------------------------------------------------------------------------------
{
  "alias": "test",
  "name": "test",
  "IRONCLAD_SPIRITS": "1",
  "THE_SILENT_SPIRITS": "1",
  "DEFECT_SPIRITS": "1",
  "DMG_DEALT": "12828",
  "IRONCLAD_WIN": "true",
  "THE_SILENT_WIN": "true",
  "NOTE_CARD": "Zap",
  "NOTE_UPGRADE": "0",
  "DEFECT_WIN": "true",
  "WATCHER_SPIRITS": "1"
}
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

# Don't check in 3rd party jars (game + mods)
lib
mods

# Communication channel with docker containers
out

# Temp directory for running the game locally
tmp

# Never check the game in, wherever it may be in the tree
desktop-1.0.jar

# Don't check in hypothesis example database at this time
.hypothesis
--------------------------------------------------------------------------------

/gym_sts/spaces/observations/components/__init__.py:
--------------------------------------------------------------------------------
from .campfire import CampfireObs  # noqa: F401
from .card_reward import CardRewardObs  # noqa: F401
from .combat import CombatObs  # noqa: F401
from .combat_reward import CombatRewardObs  # noqa: F401
from .event import EventStateObs  # noqa: F401
from .persistent import PersistentStateObs  # noqa: F401
from .shop import ShopObs  # noqa: F401
--------------------------------------------------------------------------------

/gym_sts/rl/utils.py:
--------------------------------------------------------------------------------
import typing as tp

from gymnasium import spaces


def assert_contains(space: spaces.Space, element: tp.Any):
    """Use with post-mortem debugging to see where in the space the error is."""
    if isinstance(space, spaces.Dict):
        assert isinstance(element, tp.Mapping)

        for key, subspace in space.items():
            assert_contains(subspace, element[key])
    assert space.contains(element)
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSGameplaySettings:
--------------------------------------------------------------------------------
{
  "Summed Damage": "true",
  "Blocked Damage": "true",
  "Hand Confirmation": "false",
  "Upload Data": "false",
  "Particle Effects": "true",
  "Fast Mode": "true",
  "Show Card keys": "false",
  "Bigger Text": "false",
  "Long-press Enabled": "false",
  "Screen Shake": "false",
  "Playtester Art": "false",
  "Controller Enabled": "true",
  "Touchscreen Enabled": "false",
  "LANGUAGE": "ENG",
  "Ascension Mode Default": "true"
}
--------------------------------------------------------------------------------

/gym_sts/spaces/observations/types/__init__.py:
--------------------------------------------------------------------------------
from .base import BinaryArray, Effect, Enemy, Health, Keys, Orb  # noqa: F401
from .campfire import CampfireChoice  # noqa: F401
from .cards import Card, HandCard, ShopCard  # noqa: F401
from .map import Map  # noqa: F401
from .potions import Potion, PotionBase, ShopPotion  # noqa: F401
from .relics import Relic, RelicBase, ShopRelic  # noqa: F401
from .rewards import (  # noqa: F401
    CardReward,
    GoldReward,
    KeyReward,
    PotionReward,
    RelicReward,
    Reward,
)
--------------------------------------------------------------------------------

/tests/observations/serde/test_card_reward.py:
--------------------------------------------------------------------------------
from gym_sts.envs.base import SlayTheSpireGymEnv


def test_card_reward_serde(env: SlayTheSpireGymEnv):
    env.reset(seed=42)

    env.communicator.basemod("relic add Singing_Bowl")
    env.communicator.basemod("fight 2_Orb_Walkers")
    env.communicator.basemod("kill all")

    # Open the card reward
    env.step(5)

    obs = env.observe(add_to_cache=True)
    orig_state = obs.card_reward_state
    ser = orig_state.serialize()
    de = obs.card_reward_state.deserialize(ser)

    assert orig_state == de
--------------------------------------------------------------------------------

/gym_sts/envs/types.py:
--------------------------------------------------------------------------------
from typing import Optional

from pydantic import BaseModel, validator


class ResetParams(BaseModel):
    """
    Pydantic model for validating reset() arguments.
    """

    seed: Optional[int]
    sts_seed: Optional[str]
    rng_state: Optional[tuple]
    reboot: bool = False

    @validator("rng_state")
    def seed_and_rng_state_are_mutually_exclusive(cls, v, values, **kwargs):
        if values.get("seed") is not None and v is not None:
            raise ValueError("seed and rng_state cannot both be provided")
        return v
--------------------------------------------------------------------------------

/gym_sts/build/preferences/STSUnlockProgress:
--------------------------------------------------------------------------------
{
  "IRONCLADProgress": "326",
  "IRONCLADTotalScore": "5876",
  "IRONCLADHighScore": "876",
  "IRONCLADUnlockLevel": "5",
  "IRONCLADCurrentCost": "2500",
  "THE_SILENTUnlockLevel": "5",
  "THE_SILENTProgress": "85",
  "THE_SILENTCurrentCost": "2500",
  "THE_SILENTTotalScore": "5760",
  "THE_SILENTHighScore": "2507",
  "DEFECTUnlockLevel": "5",
  "DEFECTProgress": "811",
  "DEFECTCurrentCost": "2500",
  "DEFECTTotalScore": "6361",
  "DEFECTHighScore": "1422",
  "WATCHERUnlockLevel": "5",
  "WATCHERProgress": "1130",
  "WATCHERCurrentCost": "2500",
  "WATCHERTotalScore": "6680",
  "WATCHERHighScore": "1480"
}
--------------------------------------------------------------------------------

/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
    -   id: trailing-whitespace
    -   id: end-of-file-fixer
    -   id: check-yaml
    -   id: check-added-large-files
-   repo: https://github.com/psf/black
    rev: 22.6.0
    hooks:
    -   id: black
-   repo: https://github.com/PyCQA/flake8
    rev: 5.0.4
    hooks:
    -   id: flake8
-   repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
    -   id: isort
-   repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.971
    hooks:
    -   id: mypy
        additional_dependencies: ["pydantic>=1.9.2,<2.0", "numpy>=1.23,<2.0"]
--------------------------------------------------------------------------------

/gym_sts/rl/action_masking.py:
--------------------------------------------------------------------------------
import typing as tp

import tensorflow as tf
from ray.rllib.models import ModelCatalog

# from ray.rllib.models.tf import tf_modelv2
from ray.rllib.models.tf import fcnet


class MaskedModel(fcnet.FullyConnectedNetwork):
    def forward(
        self,
        input_dict: dict[str, tf.Tensor],
        state: list[tf.Tensor],
        seq_lens: tf.Tensor,
    ) -> tp.Tuple[tf.Tensor, list[tf.Tensor]]:
        logits, state = super().forward(input_dict, state, seq_lens)

        mask = input_dict["obs"]["valid_action_mask"]
        mask = tf.cast(mask, tf.bool)
        logits = tf.where(mask, logits, tf.float32.min)

        return logits, state


def register():
    ModelCatalog.register_custom_model("masked", MaskedModel)
--------------------------------------------------------------------------------

/gym_sts/spaces/observations/utils.py:
--------------------------------------------------------------------------------
from typing import Union

import numpy as np
import numpy.typing as npt


def to_binary_array(n: int, digits: int) -> np.ndarray:
    array = [0] * digits

    idx = 0
    n_copy = n
    while n_copy > 0:
        if idx >= digits:
            raise ValueError(
                f"{n} is too large to represent with {digits} binary digits"
            )

        n_copy, r = divmod(n_copy, 2)
        if r > 0:
            array[idx] = 1
        idx += 1

    return np.array(array, dtype=np.uint8)


def from_binary_array(array: Union[list[int], npt.NDArray[np.uint]]) -> int:
    total = 0
    place_value = 1

    for digit in array:
        if digit == 1:
            total += place_value

        place_value *= 2

    return total
--------------------------------------------------------------------------------

/gym_sts/spaces/constants/rewards.py:
--------------------------------------------------------------------------------
import math
from enum import IntEnum

from .potions import NUM_POTIONS
from .relics import NUM_RELICS


class RewardType(IntEnum):
    NONE = 0  # An empty reward slot
    CARD = 1
    GOLD = 2
    KEY = 3
    POTION = 4
    RELIC = 5


NUM_REWARD_TYPES = len(RewardType)

REWARD_CARD_COUNT = 4  # Default of 3, +1 for Question Card

# Boss gold reward max * golden idol bonus * buffer in case I'm wrong
_COMBAT_REWARD_MAX_GOLD = int(105 * 1.25 * 1.25)
_COMBAT_REWARD_MAX_POTION = NUM_POTIONS
_COMBAT_REWARD_MAX_RELIC = NUM_RELICS
_COMBAT_REWARD_MAX_ID = max(
    _COMBAT_REWARD_MAX_GOLD, _COMBAT_REWARD_MAX_POTION, _COMBAT_REWARD_MAX_RELIC
)
COMBAT_REWARD_LOG_MAX_ID = math.ceil(math.log(_COMBAT_REWARD_MAX_ID, 2))

# (Card + gold + potion + 2 relics (black star) + key) * buffer in case I'm wrong
MAX_NUM_REWARDS = int((1 + 1 + 1 + 2 + 1) * 1.25)
--------------------------------------------------------------------------------

/gym_sts/spaces/observations/serializers.py:
--------------------------------------------------------------------------------
import numpy as np
import numpy.typing as npt

from gym_sts.spaces.constants import cards as card_consts
from gym_sts.spaces.constants import combat as combat_consts
from gym_sts.spaces.observations import types


def serialize_cards(cards: list[types.Card]) -> npt.NDArray[np.uint]:
    # TODO handle Searing Blow, which can be upgraded unlimited times
    serialized = [0] * card_consts.NUM_CARDS_WITH_UPGRADES
for card in cards: 13 | card_idx = card.serialize(discrete=True) 14 | 15 | if serialized[card_idx] < card_consts.MAX_COPIES_OF_CARD: 16 | serialized[card_idx] += 1 17 | 18 | return np.array(serialized, dtype=np.uint8) 19 | 20 | 21 | def serialize_orbs(orbs: list[types.Orb]) -> npt.NDArray[np.uint]: 22 | serialized = np.array([types.Orb.serialize_empty()] * combat_consts.MAX_ORB_SLOTS) 23 | 24 | for i, orb in enumerate(orbs): 25 | serialized[i] = orb.serialize() 26 | 27 | return serialized 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "gym-sts" 3 | version = "0.1.0" 4 | description = "Gym environment for Slay the Spire" 5 | authors = ["Zeus Kronion "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.9, <3.11" # Tensorflow requires <3.11 10 | docker = "^5.0.3" 11 | pydantic = "^1.9.2" 12 | ray = {extras = ["rllib"], version = "^2.4.0"} 13 | fancyflags = "^1.1" 14 | absl-py = "^1.3.0" 15 | wandb = "^0.13.11" 16 | dm-tree = "^0.1.7" 17 | tensorflow = "^2.12.0" 18 | gymnasium = "^0.26.0" 19 | 20 | [tool.poetry.dev-dependencies] 21 | ipython = "^8.4.0" 22 | pre-commit = "^2.20.0" 23 | flake8 = "^5.0.4" 24 | black = "^22.6.0" 25 | mypy = "^0.991" 26 | isort = "^5.10.1" 27 | flake8-isort = "^4.2.0" 28 | 29 | [tool.poetry.group.dev.dependencies] 30 | pytest = "^7.1.3" 31 | hypothesis = "^6.58.1" 32 | 33 | [build-system] 34 | requires = ["poetry-core>=1.0.0"] 35 | build-backend = "poetry.core.masonry.api" 36 | 37 | [tool.isort] 38 | profile = "black" 39 | include_trailing_comma = true 40 | lines_after_imports = 2 41 | -------------------------------------------------------------------------------- /tests/test_card_reward_in_combat.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from gym_sts.envs.action_validation import 
get_valid 4 | from gym_sts.spaces import actions 5 | 6 | 7 | def test_card_reward_valid_in_combat(env): 8 | env.reset(seed=42) 9 | 10 | env.communicator.basemod("deck remove all") 11 | env.communicator.basemod("deck add Discovery") 12 | time.sleep(0.5) # Wait briefly for card adding animation to complete 13 | env.communicator.basemod("fight Jaw_Worm") 14 | 15 | # At this point the only card in hand is Discovery 16 | # TODO make action selection easier 17 | action_id = actions.ACTIONS.index( 18 | actions.PlayCard(card_position=1, target_index=None) 19 | ) 20 | _, _, _, info = env.step(action_id) 21 | obs = info["observation"] 22 | assert not obs.has_error 23 | expected_valid_actions = set( 24 | [ 25 | actions.Choose(choice_index=0), 26 | actions.Choose(choice_index=1), 27 | actions.Choose(choice_index=2), 28 | ] 29 | ) 30 | assert set(get_valid(obs)) == expected_valid_actions 31 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/spaces.py: -------------------------------------------------------------------------------- 1 | from gymnasium.spaces import Dict, Discrete, MultiBinary, MultiDiscrete, Tuple 2 | 3 | import gym_sts.spaces.constants.base as base_consts 4 | import gym_sts.spaces.constants.cards as card_consts 5 | import gym_sts.spaces.constants.combat as combat_consts 6 | 7 | 8 | def generate_card_space(): 9 | # Generally beyond some number of cards you don't actually care 10 | # how many cards you have 11 | # But this could be optimized 12 | return MultiDiscrete( 13 | [card_consts.MAX_COPIES_OF_CARD + 1] * card_consts.NUM_CARDS_WITH_UPGRADES 14 | ) 15 | 16 | 17 | def generate_effect_space(): 18 | effect_space = Dict( 19 | { 20 | "sign": Discrete(2), 21 | "value": MultiBinary(combat_consts.LOG_MAX_EFFECT), 22 | } 23 | ) 24 | return Tuple([effect_space] * combat_consts.NUM_EFFECTS) 25 | 26 | 27 | def generate_health_space(): 28 | return Dict( 29 | { 30 | "hp": MultiBinary(base_consts.LOG_MAX_HP), 31 | 
"max_hp": MultiBinary(base_consts.LOG_MAX_HP), 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /gym_sts/data/state_log_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import TextIO 3 | 4 | from pydantic import BaseModel 5 | 6 | from gym_sts.spaces.actions import ACTIONS 7 | from gym_sts.spaces.observations import Observation 8 | 9 | 10 | class Step(BaseModel): 11 | state_before: dict 12 | action: int 13 | state_after: dict 14 | 15 | 16 | class StateLogLoader: 17 | def __init__(self): 18 | self.observations = [] 19 | self.steps = [] 20 | 21 | def load_file(self, fh: TextIO): 22 | objs = json.load(fh) 23 | 24 | action_strings = [act.to_command() for act in ACTIONS] 25 | 26 | prev_obs = None 27 | 28 | for obj in objs: 29 | cur_obs = Observation(obj["state_after"]).serialize() 30 | 31 | self.observations.append(cur_obs) 32 | 33 | # Store step data 34 | if prev_obs and obj["action"]: 35 | action = action_strings.index(obj["action"]) 36 | step = Step(state_before=prev_obs, action=action, state_after=cur_obs) 37 | self.steps.append(step) 38 | 39 | prev_obs = cur_obs 40 | -------------------------------------------------------------------------------- /tests/observations/serde/test_shop.py: -------------------------------------------------------------------------------- 1 | from gym_sts.envs.base import SlayTheSpireGymEnv 2 | 3 | 4 | def test_shop_serde(env: SlayTheSpireGymEnv): 5 | env.reset(seed=43) 6 | 7 | # Neow event 8 | env.step(3) 9 | env.step(4) 10 | env.step(3) 11 | 12 | # Skip a combat 13 | env.step(3) 14 | env.communicator.basemod("kill all") 15 | env.step(2) 16 | 17 | # Shop 18 | env.step(3) 19 | env.step(3) 20 | 21 | obs = env.observe(add_to_cache=True) 22 | orig_state = obs.shop_state 23 | ser = orig_state.serialize() 24 | de = obs.shop_state.deserialize(ser) 25 | 26 | assert orig_state == de 27 | 28 | # Add plenty of gold so we can 
buy one of everything 29 | env.communicator.basemod("gold add 3000") 30 | 31 | # Buy a few things 32 | env.step(6) # A card 33 | env.step(15) # A potion 34 | env.step(12) # A relic 35 | 36 | # A card removal 37 | env.step(3) 38 | env.step(3) 39 | env.step(2) 40 | 41 | obs = env.observe(add_to_cache=True) 42 | orig_state = obs.shop_state 43 | ser = orig_state.serialize() 44 | de = obs.shop_state.deserialize(ser) 45 | 46 | assert orig_state == de 47 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/types/campfire.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import Enum 4 | from typing import Union 5 | 6 | import numpy as np 7 | from gymnasium.spaces import MultiBinary 8 | 9 | from gym_sts.spaces.constants.campfire import LOG_NUM_OPTIONS 10 | from gym_sts.spaces.observations import utils 11 | 12 | from .base import BinaryArray 13 | 14 | 15 | class CampfireChoice(str, Enum): 16 | EMPTY = "EMPTY" # Indicates the absence of a choice 17 | REST = "rest" 18 | SMITH = "smith" 19 | DIG = "dig" 20 | LIFT = "lift" 21 | TOKE = "toke" 22 | RECALL = "recall" 23 | 24 | @staticmethod 25 | def space(): 26 | return MultiBinary(LOG_NUM_OPTIONS) 27 | 28 | @classmethod 29 | def serialize_empty(cls) -> BinaryArray: 30 | return cls.EMPTY.serialize() 31 | 32 | def serialize(self) -> BinaryArray: 33 | idx = list(CampfireChoice).index(self) 34 | return utils.to_binary_array(idx, LOG_NUM_OPTIONS) 35 | 36 | @classmethod 37 | def deserialize(cls, idx: Union[int, BinaryArray]) -> CampfireChoice: 38 | if isinstance(idx, np.ndarray): 39 | idx = utils.from_binary_array(idx) 40 | 41 | return list(cls)[idx] 42 | -------------------------------------------------------------------------------- /tests/observations/serde/test_combat_reward.py: -------------------------------------------------------------------------------- 1 | from gym_sts.envs.base 
import SlayTheSpireGymEnv 2 | 3 | 4 | def test_combat_serde(env: SlayTheSpireGymEnv): 5 | env.reset(seed=46) 6 | 7 | # Neow event 8 | env.step(3) 9 | env.step(4) 10 | env.step(3) 11 | 12 | # Skip a bunch of combats (and one shop) 13 | env.step(3) 14 | env.communicator.basemod("kill all") 15 | env.step(2) 16 | env.step(3) 17 | env.communicator.basemod("kill all") 18 | env.step(2) 19 | env.step(3) 20 | env.communicator.basemod("kill all") 21 | env.step(2) 22 | env.step(3) 23 | env.step(2) 24 | env.step(3) 25 | env.communicator.basemod("kill all") 26 | env.step(2) 27 | env.step(4) 28 | env.communicator.basemod("kill all") 29 | env.step(2) 30 | 31 | # Event 32 | env.step(3) 33 | env.step(3) 34 | env.step(3) 35 | env.step(2) 36 | env.step(3) 37 | 38 | # Burning elite 39 | env.step(3) 40 | env.communicator.basemod("kill all") 41 | 42 | obs = env.observe(add_to_cache=True) 43 | orig_state = obs.combat_reward_state 44 | ser = orig_state.serialize() 45 | de = obs.combat_reward_state.deserialize(ser) 46 | 47 | assert orig_state == de 48 | -------------------------------------------------------------------------------- /gym_sts/spaces/constants/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from enum import Enum 3 | 4 | 5 | MAX_HP = 999 6 | NUM_FLOORS = 55 7 | 8 | LOG_MAX_GOLD = 12 9 | LOG_MAX_HP = math.ceil(math.log(MAX_HP, 2)) 10 | LOG_NUM_FLOORS = math.ceil(math.log(NUM_FLOORS, 2)) 11 | 12 | ALL_KEYS = [ 13 | "EMERALD", 14 | "RUBY", 15 | "SAPPHIRE", 16 | ] 17 | NUM_KEYS = len(ALL_KEYS) 18 | 19 | # I don't know if 15 is enough, I know the card flipping game has at least 12 20 | NUM_CHOICES = 16 21 | 22 | 23 | class ScreenType(str, Enum): 24 | EMPTY = "EMPTY" # Indicates the absence of a screen type 25 | BOSS_REWARD = "BOSS_REWARD" # The contents of the boss chest 26 | CARD_REWARD = "CARD_REWARD" 27 | CHEST = "CHEST" 28 | COMBAT_REWARD = "COMBAT_REWARD" 29 | COMPLETE = "COMPLETE" # The screen immediately after 
defeating the Act 3/4 boss? 30 | EVENT = "EVENT" 31 | FTUE = "FTUE" 32 | GAME_OVER = "GAME_OVER" 33 | GRID = "GRID" # The contents of card piles, e.g. the discard 34 | HAND_SELECT = "HAND_SELECT" 35 | MAIN_MENU = "MAIN_MENU" 36 | MAP = "MAP" 37 | NONE = "NONE" # Has several meanings, e.g. combat 38 | REST = "REST" 39 | SHOP_ROOM = "SHOP_ROOM" # The room containing the merchant 40 | SHOP_SCREEN = "SHOP_SCREEN" # The actual shopping menu 41 | -------------------------------------------------------------------------------- /gym_sts/perf.py: -------------------------------------------------------------------------------- 1 | """Measures steps per second of the StS env. 2 | 3 | Currently I get ~1 sps. 4 | 5 | TODO: try with and without the superfast mod. 6 | """ 7 | 8 | import argparse 9 | import random 10 | import time 11 | 12 | from gym_sts.envs.base import SlayTheSpireGymEnv 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("lib_dir") 18 | parser.add_argument("mods_dir") 19 | parser.add_argument("--build_image", action="store_true") 20 | parser.add_argument("--runtime", default=30, type=int) 21 | args = parser.parse_args() 22 | 23 | if args.build_image: 24 | SlayTheSpireGymEnv.build_image() 25 | env = SlayTheSpireGymEnv(args.lib_dir, args.mods_dir, headless=True) 26 | env.reset(seed=42) 27 | rng = random.Random(42) 28 | 29 | num_steps = 0 30 | start_time = time.perf_counter() 31 | 32 | while True: 33 | action = rng.choice(env.valid_actions()) 34 | _, _, terminated, truncated, _ = env.step(action._id) 35 | if terminated or truncated: 36 | env.reset() 37 | 38 | num_steps += 1 39 | run_time = time.perf_counter() - start_time 40 | if run_time >= args.runtime: 41 | break 42 | 43 | sps = num_steps / run_time 44 | print(f"sps: {sps}") 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /gym_sts/spaces/constants/map.py:
-------------------------------------------------------------------------------- 1 | # You can read more about the structure of the map here: 2 | # https://kosgames.com/slay-the-spire-map-generation-guide-26769/ 3 | ALL_MAP_LOCATIONS = [ 4 | "NONE", # Indicates the absence of node 5 | "M", # Monster 6 | "?", # Unknown 7 | "$", # Shop 8 | "E", # Elite 9 | "B", # Burning Elite. Note that this symbol isn't actually used by the game. 10 | "T", # Treasure 11 | "R", # Rest site 12 | ] 13 | NUM_MAP_LOCATIONS = len(ALL_MAP_LOCATIONS) 14 | NUM_MAP_NODES_PER_ROW = 7 15 | NUM_MAP_ROWS = 15 16 | NUM_MAP_NODES = NUM_MAP_NODES_PER_ROW * NUM_MAP_ROWS 17 | # Nodes can only have edges to endpoints in the same column, or one column to the 18 | # left or right. Thus, we store three bits per node, representing the presence of an 19 | # edge to the left, center, and right. 20 | NUM_MAP_EDGES_PER_NODE = 3 # Max branching factor from one layer to the next 21 | NUM_MAP_EDGES = NUM_MAP_NODES_PER_ROW * NUM_MAP_EDGES_PER_NODE * (NUM_MAP_ROWS - 1) 22 | 23 | NORMAL_BOSSES = [ 24 | "NONE", # A placeholder for an "empty" observation 25 | "The Guardian", 26 | "Hexaghost", 27 | "Slime Boss", 28 | "Collector", 29 | "Automaton", 30 | "Champ", 31 | "Awakened One", 32 | "Time Eater", 33 | "Donu and Deca", 34 | ] 35 | NUM_NORMAL_BOSSES = len(NORMAL_BOSSES) 36 | -------------------------------------------------------------------------------- /gym_sts/rl/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from ray.rllib.algorithms.callbacks import DefaultCallbacks 4 | from ray.rllib.env import BaseEnv 5 | from ray.rllib.evaluation import Episode, RolloutWorker 6 | from ray.rllib.policy import Policy 7 | 8 | 9 | class StSCustomMetricCallbacks(DefaultCallbacks): 10 | def on_episode_end( 11 | self, 12 | *, 13 | worker: RolloutWorker, 14 | base_env: BaseEnv, 15 | policies: Dict[str, Policy], 16 | episode: Episode, 17 | env_index: 
int, 18 | **kwargs 19 | ): 20 | subenvs = base_env.get_sub_environments() 21 | assert len(subenvs) == 1 22 | 23 | obs = subenvs[0].observe() 24 | 25 | max_hp = sum(e.max_hp for e in obs.combat_state.enemies) 26 | enemy_hp = sum(e.current_hp for e in obs.combat_state.enemies) 27 | self_hp = obs.persistent_state.hp 28 | self_max_hp = obs.persistent_state.max_hp 29 | 30 | if self_hp == 0 or enemy_hp == 0: 31 | episode.custom_metrics["win_rate"] = 1 if self_hp > 0 else 0 32 | 33 | if enemy_hp == 0: 34 | episode.custom_metrics["win_remaining_hp"] = self_hp / self_max_hp 35 | else: 36 | episode.custom_metrics["lose_enemy_hp"] = ( 37 | enemy_hp / max_hp if max_hp > 0 else 0 38 | ) 39 | -------------------------------------------------------------------------------- /gym_sts/runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from gym_sts.envs.base import SlayTheSpireGymEnv 4 | from gym_sts.spaces.observations import ObservationError 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("lib_dir") 10 | parser.add_argument("mods_dir") 11 | parser.add_argument("out_dir") 12 | parser.add_argument("--headless", action="store_true") 13 | parser.add_argument("--build_image", action="store_true") 14 | args = parser.parse_args() 15 | 16 | if args.build_image: 17 | SlayTheSpireGymEnv.build_image() 18 | env = SlayTheSpireGymEnv( 19 | args.lib_dir, args.mods_dir, args.out_dir, headless=args.headless 20 | ) 21 | observation = env.reset() 22 | print(observation.state) 23 | 24 | while True: 25 | action = input("Enter an action: ") 26 | if not action: 27 | print("No action given. 
Defaulting to STATE.") 28 | action = "STATE" 29 | observation = env._do_action(action) 30 | print(observation.state) 31 | try: 32 | commands = observation._available_commands 33 | print("AVAILABLE COMMANDS:") 34 | print(commands) 35 | except ObservationError as e: 36 | print("ERROR") 37 | print(e) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class ObsComponent(ABC): 7 | @abstractmethod 8 | def serialize(self): 9 | """ 10 | Convert the component instance into a data structure 11 | conforming to the shape of the component's gym space. 12 | """ 13 | 14 | raise NotImplementedError("Not implemented") 15 | 16 | # TODO @kronion: uncomment once EventStateObs conforms to the API. 17 | # @classmethod 18 | # @abstractmethod 19 | # def deserialize(cls, data): 20 | # """ 21 | # Convert data matching the shape of the component's 22 | # gym space into a new component instance. 23 | # """ 24 | # 25 | # raise NotImplementedError("Not implemented") 26 | 27 | @staticmethod 28 | @abstractmethod 29 | def space(): 30 | """ 31 | Returns the shape of the component's gym space. 32 | """ 33 | 34 | raise NotImplementedError("Not implemented") 35 | 36 | 37 | class PydanticComponent(ObsComponent, BaseModel): 38 | """ 39 | A subclass of ObsComponent that is also a Pydantic model. In cases where you don't 40 | need an __init__() method to perform more complex parsing of the CommunicationMod 41 | state, you can use this class to define a dataclass-style component. 
42 | """ 43 | 44 | pass 45 | -------------------------------------------------------------------------------- /gym_sts/build/preferences/STSUnlocks: -------------------------------------------------------------------------------- 1 | { 2 | "GUARDIAN": "2", 3 | "The Silent": "2", 4 | "GHOST": "2", 5 | "CHAMP": "2", 6 | "Bane": "2", 7 | "Catalyst": "2", 8 | "Corpse Explosion": "2", 9 | "Defect": "2", 10 | "SLIME": "2", 11 | "AUTOMATON": "2", 12 | "DONUT": "2", 13 | "Heavy Blade": "2", 14 | "Spot Weakness": "2", 15 | "Limit Break": "2", 16 | "Watcher": "2", 17 | "Du-Vu Doll": "2", 18 | "Smiling Mask": "2", 19 | "Tiny Chest": "2", 20 | "COLLECTOR": "2", 21 | "WIZARD": "2", 22 | "Rebound": "2", 23 | "Undo": "2", 24 | "Echo Form": "2", 25 | "Prostrate": "2", 26 | "Blasphemy": "2", 27 | "Devotion": "2", 28 | "Cloak And Dagger": "2", 29 | "Accuracy": "2", 30 | "Storm of Steel": "2", 31 | "Omamori": "2", 32 | "Prayer Wheel": "2", 33 | "Shovel": "2", 34 | "Turbo": "2", 35 | "Sunder": "2", 36 | "Meteor Strike": "2", 37 | "Hyperbeam": "2", 38 | "Recycle": "2", 39 | "Core Surge": "2", 40 | "Wild Strike": "2", 41 | "Evolve": "2", 42 | "Immolate": "2", 43 | "Havoc": "2", 44 | "Sentinel": "2", 45 | "Exhume": "2", 46 | "Blue Candle": "2", 47 | "Dead Branch": "2", 48 | "Singing Bowl": "2", 49 | "ForeignInfluence": "2", 50 | "Alpha": "2", 51 | "MentalFortress": "2", 52 | "Art of War": "2", 53 | "The Courier": "2", 54 | "Pandora\u0027s Box": "2", 55 | "CROW": "2", 56 | "SpiritShield": "2", 57 | "Wish": "2", 58 | "Wireheading": "2" 59 | } 60 | -------------------------------------------------------------------------------- /tests/test_screenshotting.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from gym_sts.exceptions import StSTimeoutError 6 | 7 | 8 | def test_screenshot_regardless_of_animation(env, headless): 9 | if not headless: 10 | pytest.skip("Test can only run headless for 
now") 11 | 12 | filename = "test.png" 13 | filepath = env.output_dir / "screenshots" / filename 14 | 15 | for setting in [True, False]: 16 | env.set_animate(setting) 17 | 18 | # TODO confirm screenshot isn't black? 19 | assert not filepath.exists() 20 | env.screenshot(filename) 21 | assert filepath.exists() 22 | filepath.unlink() 23 | 24 | # Animation settings are restored after screenshot 25 | assert env.animate == setting 26 | 27 | 28 | def test_screenshot_on_step_error(env, headless): 29 | if not headless: 30 | pytest.skip("Test can only run headless for now") 31 | 32 | folder = env.output_dir / "screenshots" 33 | folder_contents = list(folder.iterdir()) 34 | assert len(folder_contents) == 0 35 | 36 | with patch.object( 37 | env.communicator, "_manual_command", side_effect=StSTimeoutError("testing") 38 | ): 39 | with pytest.raises(StSTimeoutError): 40 | env.step(1) 41 | folder_contents = list(folder.iterdir()) 42 | assert len(folder_contents) == 1 43 | folder_contents[0].unlink() 44 | -------------------------------------------------------------------------------- /tests/observations/serde/test_persistent.py: -------------------------------------------------------------------------------- 1 | from gym_sts.envs.base import SlayTheSpireGymEnv 2 | 3 | 4 | def test_persistent_serde(env: SlayTheSpireGymEnv): 5 | env.reset(seed=43) 6 | 7 | obs = env.observe(add_to_cache=True) 8 | orig_state = obs.persistent_state 9 | ser = orig_state.serialize() 10 | de = obs.persistent_state.deserialize(ser) 11 | 12 | assert orig_state == de 13 | 14 | # Edit health, max health, potions, relics, deck, and keys 15 | env.communicator.basemod("gold add 100") 16 | env.communicator.basemod("maxhp lose 7") 17 | env.communicator.basemod("hp lose 7") 18 | env.communicator.basemod("potions 0 Ambrosia") 19 | env.communicator.basemod("relic add Anchor") 20 | env.communicator.basemod("deck remove all") 21 | env.communicator.basemod("deck add Accuracy") 22 | env.communicator.basemod("key add 
ruby") 23 | 24 | obs = env.observe(add_to_cache=True) 25 | orig_state = obs.persistent_state 26 | ser = orig_state.serialize() 27 | de = obs.persistent_state.deserialize(ser) 28 | 29 | assert orig_state == de 30 | 31 | # Go to a different screen type (Neow is an event) 32 | env.step(3) 33 | env.step(4) 34 | env.step(3) 35 | env.step(3) 36 | 37 | obs = env.observe(add_to_cache=True) 38 | orig_state = obs.persistent_state 39 | ser = orig_state.serialize() 40 | de = obs.persistent_state.deserialize(ser) 41 | 42 | assert orig_state == de 43 | -------------------------------------------------------------------------------- /tests/env/test_single_combat_env.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from gym_sts.spaces.actions import PlayCard 4 | 5 | 6 | def test_single_combat_spawns_direct_into_combat(single_combat_env): 7 | single_combat_env.reset() 8 | obs = single_combat_env.observe() 9 | 10 | assert obs.in_combat 11 | assert obs.screen_type == "NONE" 12 | 13 | 14 | def test_single_combat_resets_after_defeating_enemy(single_combat_env): 15 | single_combat_env.reset(seed=42) 16 | 17 | single_combat_env.communicator.basemod("relic add NeowsBlessing") 18 | single_combat_env.communicator.basemod("fight Gremlin_Nob") 19 | time.sleep(0.1) 20 | single_combat_env.communicator.basemod("hand discard all") 21 | time.sleep(0.1) 22 | obs = single_combat_env.communicator.basemod("hand add Strike_B") 23 | time.sleep(0.1) 24 | 25 | # Current state should have exactly 1 Strike card in player's hand 26 | assert obs.combat_state.enemies[0].current_hp == 1 27 | assert len(obs.valid_actions) == 2 28 | assert isinstance(obs.valid_actions[1], PlayCard) 29 | 30 | # Playing a card doesn't immediately reduce the enemy's HP; it simply queues the 31 | # card to be played by the action manager 32 | _, _, should_reset, _, _ = single_combat_env.step(1) 33 | # End turn 34 | _, _, should_reset, _, info = single_combat_env.step(0) 35 | 36 |
# End of combat, we should reset 37 | assert not info["observation"].in_combat 38 | assert should_reset 39 | -------------------------------------------------------------------------------- /gym_sts/data/state_logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | from pathlib import Path 4 | 5 | from gym_sts.spaces.actions import Action 6 | from gym_sts.spaces.observations import Observation 7 | 8 | 9 | class StateLogger: 10 | def __init__(self, logdir, batch_size: int = 100, indent: int | None = None): 11 | self.logdir = Path(logdir) 12 | self.batch_size = batch_size 13 | self.indent = indent 14 | 15 | # TODO: Set this to true in the RL runner, 16 | # or make it configurable (which requires more plumbing) 17 | self.write_wandb = False 18 | 19 | self.unlogged_actions: list[dict] = [] 20 | 21 | def log(self, action: Action | None, reward: float | None, after_obs: Observation): 22 | self.unlogged_actions.append( 23 | { 24 | "action": action.to_command() if action else None, 25 | "reward": reward, 26 | "state_after": after_obs.state, 27 | } 28 | ) 29 | 30 | if len(self.unlogged_actions) >= self.batch_size: 31 | self.flush_actions() 32 | 33 | def flush_actions(self): 34 | now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 35 | outpath = self.logdir / f"states_{now}.json" 36 | with open(outpath, "w") as f: 37 | f.write(json.dumps(self.unlogged_actions, indent=self.indent)) 38 | 39 | self.unlogged_actions = [] 40 | 41 | # TODO: Implement writing to WandB 42 | 43 | print("Actions logged to", outpath) 44 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from gym_sts.envs.base import SlayTheSpireGymEnv 4 | from gym_sts.envs.single_combat import SingleCombatSTSEnv 5 | 6 | 7 | def pytest_addoption(parser): 8 | 
parser.addoption("--lib-dir", default="lib", help="Location of the lib directory") 9 | parser.addoption( 10 | "--mods-dir", default="mods", help="Location of the mods directory" 11 | ) 12 | parser.addoption( 13 | "--headless", action="store_true", help="If provided, run without a visible UI" 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def headless(request): 19 | return request.config.getoption("headless") 20 | 21 | 22 | # Because starting the game is time-consuming, we scope this fixture to the entire 23 | # test session, i.e. the same game is used across all tests. This means each test 24 | # must call `env.reset()` in order to ensure isolation from previous tests. 25 | @pytest.fixture(scope="session") 26 | def env(request): 27 | lib_dir = request.config.getoption("lib_dir") 28 | mods_dir = request.config.getoption("mods_dir") 29 | headless = request.config.getoption("headless") 30 | 31 | env = SlayTheSpireGymEnv(lib_dir, mods_dir, headless=headless) 32 | yield env 33 | 34 | env.close() 35 | 36 | 37 | @pytest.fixture(scope="module") 38 | def single_combat_env(request): 39 | lib_dir = request.config.getoption("lib_dir") 40 | mods_dir = request.config.getoption("mods_dir") 41 | headless = request.config.getoption("headless") 42 | 43 | env = SingleCombatSTSEnv(lib_dir, mods_dir, headless=headless) 44 | yield env 45 | 46 | env.close() 47 | -------------------------------------------------------------------------------- /gym_sts/build/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | WORKDIR /game 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | alsa-utils \ 7 | htop `: # useful diagnostic tool` \ 8 | libopenal1 \ 9 | openjdk-8-jre `: # java` \ 10 | x11-xserver-utils \ 11 | xvfb `: # virtual X screen` \ 12 | scrot `: # utility to take screenshots` 13 | 14 | COPY preferences preferences 15 | 16 | COPY info.displayconfig info.displayconfig 17 | 18 | # Set default sound card to index 1, which is 
expected to be a 19 | # loopback sound card created by the host 20 | COPY asound.conf /etc/asound.conf 21 | 22 | # modthespire assumes you've installed the game through steam, so we have to trick it 23 | # into thinking nothing is amiss. 24 | # NB: WORKDIR statements create non-existent directories recursively 25 | WORKDIR /root/.steam/steam/steamapps 26 | RUN touch appmanifest_646570.acf # modthespire expects this file to be present 27 | WORKDIR common/SlayTheSpire 28 | RUN ln -s /game/lib/desktop-1.0.jar 29 | WORKDIR jre/bin 30 | RUN ln -s /etc/alternatives/java # modthespire uses the java JRE included with the game 31 | 32 | WORKDIR /root/.config/ModTheSpire/CommunicationMod 33 | COPY communication_mod.config.properties config.properties 34 | 35 | WORKDIR /root/.config/ModTheSpire/SuperFastMode 36 | COPY superfastmode.config.properties SuperFastModeConfig.properties 37 | 38 | WORKDIR /game 39 | COPY pipe_to_host.sh pipe_to_host.sh 40 | 41 | ENTRYPOINT ["xvfb-run", "-e", "/dev/stdout", "-f", "/tmp/sts.xauth", "-s", "-screen 0 1024x576x24"] 42 | CMD ["java", "-jar", "/game/lib/ModTheSpire.jar", "--skip-intro", "--mods", "basemod,CommunicationMod,superfastmode"] 43 | -------------------------------------------------------------------------------- /tests/env/test_rebooting.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | 4 | def test_reboot_occurs_every_reset(env): 5 | env.reset_count = 0 6 | env.reboot_frequency = 1 7 | 8 | with patch.object(env, "_end_game") as mock_end_game: 9 | env.reset() 10 | env.reset() 11 | env.reset() 12 | 13 | mock_end_game.assert_not_called() 14 | 15 | 16 | def test_reboot_occurs_only_once(env): 17 | env.reset_count = 0 18 | env.reboot_frequency = 0 19 | env.reset() 20 | 21 | with patch.object(env, "reboot") as mock_reboot: 22 | env.reset() 23 | env.reset() 24 | mock_reboot.assert_not_called() 25 | 26 | 27 | def test_reboot_occurs_every_other_time(env): 28 | 
env.reset_count = 0 29 | env.reboot_frequency = 2 30 | 31 | with patch.object(env, "_end_game", side_effect=env._end_game) as mock_end_game: 32 | with patch.object(env, "reboot", side_effect=env.reboot) as mock_reboot: 33 | for i in range(6): 34 | env.reset() 35 | if i % 2 == 0: 36 | mock_reboot.assert_called_once() 37 | else: 38 | mock_end_game.assert_called_once() 39 | 40 | mock_reboot.reset_mock() 41 | mock_end_game.reset_mock() 42 | 43 | 44 | def test_force_reboot(env): 45 | env.reset_count = 0 46 | env.reboot_frequency = 0 47 | env.reset() 48 | 49 | with patch.object(env, "reboot", side_effect=env.reboot) as mock_reboot: 50 | env.reset() 51 | mock_reboot.assert_not_called() 52 | 53 | env.reset(options={"reboot": True}) 54 | mock_reboot.assert_called_once() 55 | -------------------------------------------------------------------------------- /tests/observations/serde/types/test_rewards.py: -------------------------------------------------------------------------------- 1 | from hypothesis import given 2 | from hypothesis import strategies as st 3 | 4 | import gym_sts.spaces.constants.base as base_consts 5 | from gym_sts.spaces.observations import types 6 | 7 | from .test_potions import create_potion_base 8 | from .test_relics import create_relic_base 9 | 10 | 11 | @given(st.builds(types.GoldReward)) 12 | def test_gold_reward_serde(reward: types.GoldReward): 13 | ser = reward.serialize() 14 | de = reward.deserialize(ser) 15 | 16 | assert reward == de 17 | 18 | 19 | @st.composite 20 | def create_potion_reward(draw, potions=create_potion_base()): 21 | potion = draw(potions) 22 | return types.PotionReward(value=potion) 23 | 24 | 25 | @given(create_potion_reward()) 26 | def test_potion_reward_serde(reward: types.PotionReward): 27 | ser = reward.serialize() 28 | de = reward.deserialize(ser) 29 | 30 | assert reward == de 31 | 32 | 33 | @st.composite 34 | def create_relic_reward(draw, relics=create_relic_base()): 35 | relic = draw(relics) 36 | return 
types.RelicReward(value=relic) 37 | 38 | 39 | @given(create_relic_reward()) 40 | def test_relic_reward_serde(reward: types.RelicReward): 41 | ser = reward.serialize() 42 | de = reward.deserialize(ser) 43 | 44 | assert reward == de 45 | 46 | 47 | @given(st.builds(types.CardReward)) 48 | def test_card_reward_serde(reward: types.CardReward): 49 | ser = reward.serialize() 50 | de = reward.deserialize(ser) 51 | 52 | assert reward == de 53 | 54 | 55 | @given(st.builds(types.KeyReward, value=st.sampled_from(base_consts.ALL_KEYS))) 56 | def test_key_reward_serde(reward: types.KeyReward): 57 | ser = reward.serialize() 58 | de = reward.deserialize(ser) 59 | 60 | assert reward == de 61 | -------------------------------------------------------------------------------- /gym_sts/communication/receiver.py: -------------------------------------------------------------------------------- 1 | import fcntl 2 | import json 3 | import os 4 | import time 5 | 6 | from gym_sts import exceptions 7 | 8 | 9 | class Receiver: 10 | def __init__(self, fn, timeout: float = 50): 11 | self.fh = open(fn, "r") 12 | 13 | # Put the pipe in non-blocking mode so reads return immediately when it is empty 14 | flag = fcntl.fcntl(self.fh, fcntl.F_GETFL) 15 | fcntl.fcntl(self.fh, fcntl.F_SETFL, flag | os.O_NONBLOCK) 16 | 17 | self.timeout = timeout 18 | self.sleep_time = 0.05 19 | self.num_steps = int(timeout / self.sleep_time) 20 | 21 | def empty_fifo(self) -> None: 22 | """ 23 | Read and discard all pipe content. 24 | 25 | Typically the caller would do this to ensure that the next message on the fifo 26 | corresponds to the result of the next action sent to the game. 27 | """
27 | """ 28 | 29 | self.fh.readlines() 30 | 31 | def receive_game_state(self) -> dict: 32 | """ 33 | Continues reading game state until the game is waiting for action from 34 | the agent 35 | """ 36 | for _ in range(self.num_steps): 37 | message = self.fh.readline() 38 | if len(message) > 0: 39 | try: 40 | state = json.loads(message) 41 | if state["ready_for_command"]: 42 | return state 43 | except json.decoder.JSONDecodeError: 44 | print( 45 | "W: Message not in valid JSON, retrying. Contents: " + message 46 | ) 47 | 48 | time.sleep(self.sleep_time) 49 | 50 | raise exceptions.StSTimeoutError( 51 | f"Waited {self.timeout} seconds for game state to be ready " 52 | "for command, but it didn't happen." 53 | ) 54 | -------------------------------------------------------------------------------- /tests/observations/serde/test_combat.py: -------------------------------------------------------------------------------- 1 | from gym_sts.envs.base import SlayTheSpireGymEnv 2 | 3 | 4 | def test_combat_serde(env: SlayTheSpireGymEnv): 5 | env.reset() 6 | 7 | # Enter combat 8 | env.communicator.basemod("fight 2_Orb_Walkers") 9 | 10 | obs = env.observe(add_to_cache=True) 11 | orig_state = obs.combat_state 12 | ser = orig_state.serialize() 13 | de = obs.combat_state.deserialize(ser) 14 | 15 | assert orig_state == de 16 | 17 | 18 | def test_dead_minions_dont_overflow_serde_bounds(env: SlayTheSpireGymEnv): 19 | """ 20 | Combats with large numbers of minions shouldn't cause exceptions during serde. 21 | 22 | There should only be up to 6 enemies on-screen at a time, but 23 | CommunicationMod continues to send data about dead enemies/minions. 24 | Confirm that we filter out the dead enemies/minions to guarantee we don't 25 | overflow the bounds of the serialization representation. 
26 | """ 27 | 28 | env.reset(seed=42) 29 | 30 | # Enter combat 31 | env.communicator.basemod("fight Reptomancer") 32 | 33 | # Add enough HP that we can just wait for tons of minions to spawn and kamikaze 34 | env.communicator.basemod("maxhp add 900") 35 | 36 | for _ in range(22): 37 | obs = env.observe(add_to_cache=True) 38 | orig_state = obs.combat_state 39 | ser = orig_state.serialize() 40 | de = obs.combat_state.deserialize(ser) 41 | 42 | assert orig_state.enemies == de.enemies 43 | 44 | # End turn 45 | env.step(0) 46 | 47 | # Confirm that attacking a specific enemy index works as expected. 48 | # We attack one of Reptomancer's daggers and confirm that its health decreases. 49 | env.step(75) 50 | obs = env.observe(add_to_cache=True) 51 | enemies = obs.combat_state.enemies 52 | assert enemies[1].current_hp == 16 53 | -------------------------------------------------------------------------------- /gym_sts/spaces/constants/events.py: -------------------------------------------------------------------------------- 1 | ALL_EVENTS = [ 2 | "NONE", # Indicates the absence of an event 3 | "Shining Light", 4 | "World of Goop", 5 | "Mushrooms", 6 | "The Cleric", 7 | "Dead Adventurer", 8 | "Living Wall", 9 | "Big Fish", 10 | "Liars Game", 11 | "Scrap Ooze", 12 | "Golden Wing", 13 | "Golden Idol", 14 | "Beggar", 15 | "Colosseum", 16 | "The Mausoleum", 17 | "The Library", 18 | "Addict", 19 | "Cursed Tome", 20 | "The Joust", 21 | "Forgotten Altar", 22 | "Masked Bandits", 23 | "Drug Dealer", 24 | "Knowing Skull", 25 | "Back to Basics", 26 | "Vampires", 27 | "Nest", 28 | "Ghosts", 29 | "Mysterious Sphere", 30 | "Tomb of Lord Red Mask", 31 | "SecretPortal", 32 | "The Moai Head", 33 | "Spire Heart", 34 | "SensoryStone", 35 | "MindBloom", 36 | "Falling", 37 | "Winding Halls", 38 | "Golden Shrine", 39 | "Accursed Blacksmith", 40 | "Designer", 41 | "Fountain of Cleansing", 42 | "Wheel of Change", 43 | "Duplicator", 44 | "The Woman in Blue", 45 | "Match and Keep!", 46 | 
"NoteForYourself", 47 | "WeMeetAgain", 48 | "Transmorgrifier", 49 | "N'loth", 50 | "Bonfire Elementals", 51 | "Purifier", 52 | "Upgrade Shrine", 53 | "Lab", 54 | "FaceTrader", 55 | "Neow Event", 56 | ] 57 | NUM_EVENTS = len(ALL_EVENTS) 58 | 59 | # Most (if not all) numbers in random events occur in the middle of the text, so a 60 | # leading space is enough to match the left word boundary. 61 | # The right word boundary is intentionally omitted, so e.g. " 1" and " 10" still 62 | # match " 100%". 63 | GLOBALLY_CHECKED_TEXTS = [f" {x}" for x in range(100)] 64 | 65 | NUM_GLOBALLY_CHECKED_TEXTS = len(GLOBALLY_CHECKED_TEXTS) 66 | MAX_NUM_CUSTOM_TEXTS = 30 67 | MAX_NUM_TEXTS = MAX_NUM_CUSTOM_TEXTS + NUM_GLOBALLY_CHECKED_TEXTS 68 | -------------------------------------------------------------------------------- /tests/observations/serde/types/test_base.py: -------------------------------------------------------------------------------- 1 | from hypothesis import given 2 | from hypothesis import strategies as st 3 | 4 | import gym_sts.spaces.constants.combat as combat_consts 5 | from gym_sts.spaces.observations.types.base import ( 6 | Effect, 7 | Enemy, 8 | Health, 9 | Keys, 10 | Orb, 11 | ShopMixin, 12 | ) 13 | 14 | 15 | @given(st.builds(Effect, id=st.sampled_from(combat_consts.ALL_EFFECTS))) 16 | def test_effect_serde(effect: Effect): 17 | ser = effect.serialize() 18 | de = effect.deserialize(ser) 19 | 20 | # Note that the effect ID isn't serialized, so it's not recovered 21 | # TODO serialize ID as well and then discard if unwanted 22 | assert effect.amount == de.amount 23 | 24 | 25 | @given( 26 | st.builds( 27 | Enemy, 28 | id=st.sampled_from(combat_consts.ALL_MONSTER_TYPES), 29 | intent=st.sampled_from(combat_consts.ALL_INTENTS), 30 | ) 31 | ) 32 | def test_enemy_serde(enemy: Enemy): 33 | ser = enemy.serialize() 34 | de = enemy.deserialize(ser) 35 | 36 | assert enemy == de 37 | 38 | 39 | @given(st.builds(Health)) 40 | def test_health_serde(health: Health): 41 | ser = health.serialize()
ser = health.serialize() 42 | de = health.deserialize(ser) 43 | 44 | assert health == de 45 | 46 | 47 | @given(st.builds(Keys)) 48 | def test_keys_serde(keys: Keys): 49 | ser = keys.serialize() 50 | de = Keys.deserialize(ser) 51 | assert keys == de 52 | 53 | 54 | @given(st.builds(Orb, id=st.sampled_from(combat_consts.ALL_ORBS))) 55 | def test_orb_serde(orb: Orb): 56 | ser = orb.serialize() 57 | de = orb.deserialize(ser) 58 | 59 | assert orb == de 60 | 61 | 62 | @given(st.builds(ShopMixin)) 63 | def test_shop_mixin_serde(mixin: ShopMixin): 64 | ser = mixin.serialize_price() 65 | de = ShopMixin.deserialize_price(ser) 66 | 67 | assert mixin.price == de 68 | 69 | 70 | def test_shop_mixin_serde_empty(): 71 | ser = ShopMixin.serialize_empty_price() 72 | de = ShopMixin.deserialize_price(ser) 73 | 74 | assert de == 0 75 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/campfire.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | from gymnasium.spaces import Dict, Discrete, Space, Tuple 6 | from pydantic import BaseModel, Field 7 | 8 | import gym_sts.spaces.constants.campfire as campfire_consts 9 | from gym_sts.spaces.observations import types 10 | 11 | from .base import PydanticComponent 12 | 13 | 14 | class CampfireObs(PydanticComponent): 15 | options: list[types.CampfireChoice] = Field([], alias="rest_options") 16 | has_rested: bool = False 17 | 18 | @staticmethod 19 | def space() -> Space: 20 | return Dict( 21 | { 22 | "options": Tuple( 23 | [types.CampfireChoice.space()] * campfire_consts.MAX_NUM_OPTIONS 24 | ), 25 | "has_rested": Discrete(2), 26 | } 27 | ) 28 | 29 | def serialize(self) -> dict: 30 | options = [ 31 | types.CampfireChoice.serialize_empty() 32 | ] * campfire_consts.MAX_NUM_OPTIONS 33 | for i, option in enumerate(self.options): 34 | options[i] = option.serialize() 35 | 36 
| return { 37 | "options": options, 38 | "has_rested": int(self.has_rested), 39 | } 40 | 41 | class SerializedState(BaseModel): 42 | options: list[types.BinaryArray] 43 | has_rested: int 44 | 45 | class Config: 46 | arbitrary_types_allowed = True 47 | 48 | @classmethod 49 | def deserialize(cls, data: Union[dict, SerializedState]) -> CampfireObs: 50 | if not isinstance(data, cls.SerializedState): 51 | data = cls.SerializedState(**data) 52 | 53 | options = [] 54 | for o in data.options: 55 | option = types.CampfireChoice.deserialize(o) 56 | if option != types.CampfireChoice.EMPTY: 57 | options.append(option) 58 | 59 | has_rested = bool(data.has_rested) 60 | 61 | return cls(rest_options=options, has_rested=has_rested) 62 | -------------------------------------------------------------------------------- /gym_sts/envs/single_combat.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional 2 | 3 | from gym_sts.spaces.observations import Observation 4 | 5 | from .base import SlayTheSpireGymEnv 6 | from .utils import single_combat_value 7 | 8 | 9 | class SingleCombatSTSEnv(SlayTheSpireGymEnv): 10 | def __init__( 11 | self, 12 | *args, 13 | value_fn: Callable[[Observation], float] = single_combat_value, 14 | enemies: List[str] = ["3_Sentries"], 15 | cards: Optional[List[str]] = None, 16 | add_relics: Optional[List[str]] = None, 17 | **kwargs, 18 | ): 19 | super().__init__(*args, value_fn=value_fn, **kwargs) # type: ignore[misc] 20 | self.enemies = enemies 21 | self.cards = cards if cards is not None else [] 22 | self.add_relics = add_relics if add_relics is not None else [] 23 | 24 | def reset(self, *args, **kwargs): 25 | super().reset(*args, **kwargs) # type: ignore[misc] 26 | 27 | # prng should have already been set in super().reset 28 | assert self.prng is not None 29 | 30 | self.communicator.basemod("deck remove all") 31 | 32 | for card in self.cards: 33 | self.communicator.basemod(f"deck add {card}") 34 | 35 | for relic in self.add_relics: 36 | self.communicator.basemod(f"relic add {relic}") 37 | 38 |
| enemy = self.prng.choice(self.enemies) 39 | obs = self.communicator.basemod(f"fight {enemy}") 40 | 41 | assert obs.in_combat 42 | self.observation_cache.append(obs) 43 | 44 | info = { 45 | "seed": self.seed, 46 | "sts_seed": self.sts_seed, 47 | "rng_state": self.prng.getstate(), 48 | "observation": obs, 49 | } 50 | return obs.serialize(), info 51 | 52 | def step(self, action_id: int): 53 | ser, reward, should_reset, truncated, info = super().step(action_id) 54 | 55 | # When you test out a new combat, make sure this condition works 56 | if not info["observation"].in_combat: 57 | should_reset = True 58 | 59 | return ser, reward, should_reset, truncated, info 60 | -------------------------------------------------------------------------------- /gym_sts/communication/sender.py: -------------------------------------------------------------------------------- 1 | class Sender: 2 | def __init__(self, fn): 3 | self.fh = open(fn, "w") 4 | 5 | def _send_message(self, msg: str) -> None: 6 | self.fh.write(f"{msg}\n") 7 | self.fh.flush() 8 | 9 | def send_ready(self) -> None: 10 | self._send_message("READY") 11 | 12 | def send_start(self, player_class: str, ascension: int, seed: str) -> None: 13 | self._send_message(f"START {player_class} {ascension} {seed}") 14 | 15 | def send_proceed(self) -> None: 16 | self._send_message("PROCEED") 17 | 18 | def send_choose(self, choice) -> None: 19 | self._send_message(f"CHOOSE {choice}") 20 | 21 | def send_click(self, x: int, y: int, left: bool = True) -> None: 22 | side = "left" if left else "right" 23 | self._send_message(f"CLICK {side} {x} {y}") 24 | 25 | def send_play(self, index, target) -> None: 26 | # NOTE: Card index argument is indexed from 1, with 0 representing position 10. 27 | # Indices can change in the middle of a game. 28 | # Target argument is indexed from 0. 
29 | self._send_message(f"PLAY {index} {target}") 30 | 31 | def send_end(self) -> None: 32 | self._send_message("END") 33 | 34 | def send_potion(self, action, slot, target) -> None: 35 | self._send_message(f"POTION {action} {slot} {target}") 36 | 37 | def send_resign(self) -> None: 38 | self._send_message("RESIGN") 39 | 40 | def send_wait(self, frames: int) -> None: 41 | self._send_message(f"WAIT {frames}") 42 | 43 | def send_state(self) -> None: 44 | """ 45 | Get the JSON representation of the current game state, regardless of whether or 46 | not the game is "stable." This method is valid in all game states. 47 | """ 48 | 49 | self._send_message("STATE") 50 | 51 | def send_basemod(self, command: str) -> None: 52 | self._send_message(f"BASEMOD {command}") 53 | 54 | def send_render(self, render: bool) -> None: 55 | self._send_message(f"RENDER {render}") 56 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/event.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Dict, Discrete, MultiBinary 3 | 4 | import gym_sts.spaces.constants.events as event_consts 5 | from gym_sts.spaces.data import EVENT_DATA 6 | 7 | from .base import ObsComponent 8 | 9 | 10 | # Processing text is annoying, especially when text is the only real indicator for event 11 | # state and the numbers can sometimes change. 12 | # So, each text fragment (extracted from game files) has an individual bit flag that 13 | # indicates whether the given text is present in the current event. 
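The bit-flag scheme described in the comment above can be shown with a small stand-alone sketch. It follows the same ordering as `EventData.find_matches` (global fragments first, then event-specific ones) and the same padding as `EventStateObs.serialize`. It is only an illustration. `example_text_flags`, `EXAMPLE_MAX_TEXTS`, and the sample strings are hypothetical, and the overflow check is an addition that the original code does not have.

```python
from typing import List

# Hypothetical stand-in for MAX_NUM_TEXTS in gym_sts.spaces.constants.events.
EXAMPLE_MAX_TEXTS = 6


def example_text_flags(
    global_texts: List[str], event_texts: List[str], raw_text: str
) -> List[bool]:
    """Return one presence flag per known text fragment, padded to a fixed width."""
    # Fixed order, as in EventData.find_matches: global fragments first,
    # then the fragments that belong to the current event.
    fragments = global_texts + event_texts
    flags = [fragment in raw_text for fragment in fragments]

    # Assumption for this sketch: refuse inputs that would not fit the space.
    if len(flags) > EXAMPLE_MAX_TEXTS:
        raise ValueError("more text fragments than the observation space holds")

    # Pad with False so every event produces a vector of the same length.
    flags.extend([False] * (EXAMPLE_MAX_TEXTS - len(flags)))
    return flags


# Hypothetical event: body text and option texts concatenated, as find_raw_text does.
raw = "You find a golden idol." + "Take the idol." + "Leave"
print(example_text_flags(["Leave"], ["You find a golden idol.", "Take", "Outrun"], raw))
# [True, True, True, False, False, False]
```

One caveat: `find_raw_text` joins the body text and option texts with no separator, so a fragment could in principle match across the boundary between two adjacent pieces of text. Joining with a separator such as a newline would rule that out.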
14 | 15 | 16 | class EventStateObs(ObsComponent): 17 | def __init__(self, state: dict): 18 | self.event_id = "NONE" 19 | self.raw_text = "" 20 | self.text_matches = [] 21 | 22 | if "game_state" in state: 23 | game_state = state["game_state"] 24 | if "screen_type" in game_state and game_state["screen_type"] == "EVENT": 25 | screen_state = game_state["screen_state"] 26 | self.event_id = screen_state["event_id"] 27 | self.find_raw_text(screen_state) 28 | 29 | self.text_matches = EVENT_DATA.find_matches( 30 | self.event_id, self.raw_text 31 | ) 32 | 33 | def find_raw_text(self, screen_state): 34 | texts = [screen_state["body_text"]] 35 | 36 | for opt in screen_state["options"]: 37 | texts.append(opt["text"]) 38 | 39 | self.raw_text = "".join(texts) 40 | 41 | @staticmethod 42 | def space(): 43 | return Dict( 44 | { 45 | "event_id": Discrete(event_consts.NUM_EVENTS), 46 | "text": MultiBinary(event_consts.MAX_NUM_TEXTS), 47 | } 48 | ) 49 | 50 | def serialize(self) -> dict: 51 | text = [flag for _, flag in self.text_matches] 52 | text.extend([False] * (event_consts.MAX_NUM_TEXTS - len(text))) 53 | return { 54 | "event_id": event_consts.ALL_EVENTS.index(self.event_id), 55 | "text": np.array(text), 56 | } 57 | -------------------------------------------------------------------------------- /tests/observations/serde/test_campfire.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from gym_sts.envs.base import SlayTheSpireGymEnv 4 | from gym_sts.spaces.observations.components import CampfireObs 5 | from gym_sts.spaces.observations.types import CampfireChoice 6 | 7 | 8 | def test_campfire_serde(env: SlayTheSpireGymEnv): 9 | env.reset(seed=42) 10 | 11 | # Neow event 12 | env.step(3) 13 | env.step(4) 14 | env.step(3) 15 | 16 | # First combat 17 | env.step(6) 18 | time.sleep(0.5) # Wait briefly for card adding animation to complete 19 | env.communicator.basemod("kill all") 20 | env.step(2) 21 | 22 | # Event 23 | env.step(3) 
24 | env.step(4) 25 | env.step(3) 26 | 27 | # Event 28 | env.step(3) 29 | env.step(3) 30 | env.step(3) 31 | 32 | # Event 33 | env.step(3) 34 | env.step(4) 35 | env.step(3) 36 | env.step(2) 37 | env.step(3) 38 | 39 | # Combat 40 | env.step(3) 41 | time.sleep(0.5) # Wait briefly for card adding animation to complete 42 | env.communicator.basemod("kill all") 43 | env.step(2) 44 | 45 | # Finally, a campfire 46 | env.step(3) 47 | 48 | obs = env.observe(add_to_cache=True) 49 | orig_state = obs.campfire_state 50 | ser = orig_state.serialize() 51 | de = obs.campfire_state.deserialize(ser) 52 | 53 | assert orig_state == de 54 | 55 | 56 | def test_rest_option_order_matters(): 57 | dig_then_toke = [ 58 | CampfireChoice.REST, 59 | CampfireChoice.SMITH, 60 | CampfireChoice.DIG, 61 | CampfireChoice.TOKE, 62 | CampfireChoice.RECALL, 63 | ] 64 | 65 | toke_then_dig = [ 66 | CampfireChoice.REST, 67 | CampfireChoice.SMITH, 68 | CampfireChoice.TOKE, 69 | CampfireChoice.DIG, 70 | CampfireChoice.RECALL, 71 | ] 72 | 73 | obs1 = CampfireObs(has_rested=False, rest_options=dig_then_toke) 74 | obs2 = CampfireObs(has_rested=False, rest_options=toke_then_dig) 75 | 76 | assert obs1 != obs2 77 | 78 | ser1 = obs1.serialize() 79 | ser2 = obs2.serialize() 80 | de1 = CampfireObs.deserialize(ser1) 81 | de2 = CampfireObs.deserialize(ser2) 82 | 83 | assert obs1 == de1 84 | assert obs2 == de2 85 | assert de1 != de2 86 | -------------------------------------------------------------------------------- /tests/observations/serde/types/test_relics.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from hypothesis import given 4 | from hypothesis import strategies as st 5 | 6 | from gym_sts.spaces.constants.relics import RelicCatalog 7 | from gym_sts.spaces.observations.types import Relic, RelicBase, ShopRelic 8 | 9 | 10 | @st.composite 11 | def create_relic_base(draw, relic_ids=st.sampled_from(RelicCatalog.ids)): 12 | relic_id = 
draw(relic_ids) 13 | relic_metadata = getattr(RelicCatalog, relic_id) 14 | 15 | return draw( 16 | st.builds( 17 | RelicBase, 18 | id=st.just(relic_id), 19 | name=st.just(relic_metadata.name), 20 | ) 21 | ) 22 | 23 | 24 | @given(create_relic_base()) 25 | def test_relic_base_serde_binary(relic: RelicBase): 26 | ser = relic.serialize() 27 | de = relic.deserialize(ser) 28 | 29 | assert relic == de 30 | 31 | 32 | @given(create_relic_base()) 33 | def test_relic_base_serde_discrete(relic: RelicBase): 34 | ser = relic.serialize(discrete=True) 35 | de = relic.deserialize(ser) 36 | 37 | assert relic == de 38 | 39 | 40 | def test_relic_base_serde_empty_binary(): 41 | ser = RelicBase.serialize_empty() 42 | de = RelicBase.deserialize(ser) 43 | 44 | assert de.id == RelicCatalog.NONE.id 45 | 46 | 47 | def test_relic_base_serde_empty_discrete(): 48 | ser = RelicBase.serialize_empty(discrete=True) 49 | de = RelicBase.deserialize(ser) 50 | 51 | assert de.id == RelicCatalog.NONE.id 52 | 53 | 54 | def create_relic_subclass(Model: Union[type[Relic], type[ShopRelic]]): 55 | @st.composite 56 | def create_subclass(draw, relic_bases=create_relic_base()): 57 | relic_base = draw(relic_bases) 58 | 59 | return draw( 60 | st.builds( 61 | Model, 62 | id=st.just(relic_base.id), 63 | name=st.just(relic_base.name), 64 | ) 65 | ) 66 | 67 | return create_subclass() 68 | 69 | 70 | @given(create_relic_subclass(Relic)) 71 | def test_relic_serde(relic: Relic): 72 | ser = relic.serialize() 73 | de = relic.deserialize(ser) 74 | 75 | assert relic == de 76 | 77 | 78 | def test_relic_serde_empty(): 79 | ser = Relic.serialize_empty() 80 | de = Relic.deserialize(ser) 81 | 82 | assert de.id == RelicCatalog.NONE.id 83 | 84 | 85 | @given(create_relic_subclass(ShopRelic)) 86 | def test_shop_relic_serde(relic: ShopRelic): 87 | ser = relic.serialize() 88 | de = relic.deserialize(ser) 89 | 90 | assert relic == de 91 | 92 | 93 | def test_shop_relic_serde_empty(): 94 | ser = ShopRelic.serialize_empty() 95 | de = 
ShopRelic.deserialize(ser) 96 | 97 | assert de.id == RelicCatalog.NONE.id 98 | assert de.price == 0 99 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/card_reward.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | from gymnasium.spaces import Dict, Discrete, MultiBinary, Tuple 6 | from pydantic import BaseModel, Field 7 | 8 | import gym_sts.spaces.constants.cards as card_consts 9 | import gym_sts.spaces.constants.rewards as reward_consts 10 | from gym_sts.spaces.constants.cards import CardCatalog 11 | from gym_sts.spaces.observations import types 12 | 13 | from .base import PydanticComponent 14 | 15 | 16 | class CardRewardObs(PydanticComponent): 17 | cards: list[types.Card] = [] 18 | singing_bowl: bool = Field(False, alias="bowl_available") 19 | skippable: bool = Field(False, alias="skip_available") 20 | 21 | @staticmethod 22 | def space(): 23 | return Dict( 24 | { 25 | # At most 4 cards may be offered (due to Question Card relic). 
26 | "cards": Tuple( 27 | ( 28 | MultiBinary(card_consts.LOG_NUM_CARDS_WITH_UPGRADES), 29 | MultiBinary(card_consts.LOG_NUM_CARDS_WITH_UPGRADES), 30 | MultiBinary(card_consts.LOG_NUM_CARDS_WITH_UPGRADES), 31 | MultiBinary(card_consts.LOG_NUM_CARDS_WITH_UPGRADES), 32 | ) 33 | ), 34 | "singing_bowl": Discrete(2), 35 | "skippable": Discrete(2), 36 | } 37 | ) 38 | 39 | def serialize(self) -> dict: 40 | serialized_cards = [ 41 | types.Card.serialize_empty() 42 | ] * reward_consts.REWARD_CARD_COUNT 43 | for i, card in enumerate(self.cards): 44 | serialized_cards[i] = card.serialize() 45 | 46 | return { 47 | "cards": serialized_cards, 48 | "singing_bowl": int(self.singing_bowl), 49 | "skippable": int(self.skippable), 50 | } 51 | 52 | class SerializedState(BaseModel): 53 | cards: list[types.BinaryArray] 54 | singing_bowl: int 55 | skippable: int 56 | 57 | class Config: 58 | arbitrary_types_allowed = True 59 | 60 | @classmethod 61 | def deserialize(cls, data: Union[dict, SerializedState]) -> CardRewardObs: 62 | if not isinstance(data, cls.SerializedState): 63 | data = cls.SerializedState(**data) 64 | 65 | cards = [] 66 | for c in data.cards: 67 | card = types.Card.deserialize(c) 68 | if card.id != CardCatalog.NONE.id: 69 | cards.append(card) 70 | 71 | singing_bowl = bool(data.singing_bowl) 72 | skippable = bool(data.skippable) 73 | 74 | return cls(cards=cards, bowl_available=singing_bowl, skip_available=skippable) 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gym-sts 2 | === 3 | 4 | An OpenAI Gym env for Slay the Spire. 5 | 6 | All installation commands are expected to be run as root. 7 | 8 | # Requirements 9 | 10 | - Java JDK 8 11 | - Docker 12 | - A copy of Slay the Spire, particularly its `desktop-1.0.jar` file. 
13 | - The JAR files for several mods: 14 | - ModTheSpire 15 | - BaseMod 16 | - CommunicationMod 17 | - SuperFastMode 18 | 19 | This package has been tested with Python 3.9, but more recent Pythons may also work 20 | as long as you can build dependencies (there may not be prebuilt wheels). 21 | 22 | # Installation 23 | 24 | IMPORTANT: These instructions assume you're a developer of this library, not that you're 25 | trying to add this library as a dependency for another project. 26 | 27 | 1. Install the package along with its dependencies: 28 | 29 | ```zsh 30 | pip install -e . 31 | 32 | # Alternatively, you can use poetry 33 | poetry install 34 | ``` 35 | 36 | 2. Pull in required jar files. Directory structure should look like this: 37 | 38 | ``` 39 | gym-sts/ 40 | gym_sts/ 41 | ... 42 | lib/ 43 | desktop-1.0.jar 44 | ModTheSpire.jar 45 | mods/ 46 | BaseMod.jar 47 | CommunicationMod.jar 48 | SuperFastMode.jar 49 | ``` 50 | 51 | # Build 52 | 53 | To build the Docker container: 54 | 55 | ```python 56 | from gym_sts.envs.base import SlayTheSpireGymEnv 57 | 58 | SlayTheSpireGymEnv.build_image() 59 | ``` 60 | 61 | # Setup 62 | 63 | ```zsh 64 | # On the host, make sure a loopback sound card has been created 65 | modprobe snd-aloop # TODO can an index be assigned? 66 | ``` 67 | 68 | # Run 69 | 70 | ## Run the game headless in a Docker container (preferred) 71 | 72 | The Python script will start the container. You can communicate with the game using 73 | CommunicationMod commands via stdin. 74 | 75 | ``` 76 | python3 -m gym_sts.runner [lib_dir] [mod_dir] [out_dir] --headless 77 | ``` 78 | 79 | ## Run the game directly on the host 80 | 81 | The Python script will start the game as a subprocess. This allows for easy observation 82 | of gameplay. You can communicate with the game using CommunicationMod commands via 83 | stdin. 84 | 85 | 86 | ``` 87 | python3 -m gym_sts.runner [lib_dir] [mod_dir] [out_dir] 88 | ``` 89 | 90 | # Developer guide 91 | 92 | 1. 
After cloning the project, use [poetry](https://python-poetry.org/) to install dependencies 93 | in a virtual environment: 94 | 95 | ```zsh 96 | poetry install 97 | 98 | # Enter the venv in a subshell 99 | poetry shell 100 | ``` 101 | 102 | This project uses [pre-commit](https://pre-commit.com/) to configure various 103 | linters/fixers. Make sure you've installed the pre-commit hook: 104 | 105 | ```zsh 106 | # Confirm pre-commit is installed 107 | pre-commit -V 108 | 109 | # Install the hook 110 | pre-commit install 111 | 112 | # Optionally run the linters manually 113 | pre-commit run --all-files 114 | ``` 115 | 116 | ## Tests 117 | 118 | Simply run `pytest` from the project's root directory. 119 | -------------------------------------------------------------------------------- /tests/observations/serde/types/test_potions.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from hypothesis import given 4 | from hypothesis import strategies as st 5 | 6 | from gym_sts.spaces.constants.potions import PotionCatalog 7 | from gym_sts.spaces.observations.types import Potion, PotionBase, ShopPotion 8 | 9 | 10 | @st.composite 11 | def create_potion_base(draw, potion_ids=st.sampled_from(PotionCatalog.ids)): 12 | potion_id = draw(potion_ids) 13 | potion_metadata = getattr(PotionCatalog, potion_id) 14 | 15 | return draw( 16 | st.builds( 17 | PotionBase, 18 | id=st.just(potion_id), 19 | name=st.just(potion_metadata.name), 20 | requires_target=st.just(potion_metadata.requires_target), 21 | ) 22 | ) 23 | 24 | 25 | @given(create_potion_base()) 26 | def test_potion_base_serde_binary(potion: PotionBase): 27 | ser = potion.serialize() 28 | de = potion.deserialize(ser) 29 | 30 | assert potion == de 31 | 32 | 33 | @given(create_potion_base()) 34 | def test_potion_base_serde_discrete(potion: PotionBase): 35 | ser = potion.serialize(discrete=True) 36 | de = potion.deserialize(ser) 37 | 38 | assert potion == de 39 | 40 | 41 |
def test_potion_base_serde_empty_binary(): 42 | ser = PotionBase.serialize_empty() 43 | de = PotionBase.deserialize(ser) 44 | 45 | assert de.id == PotionCatalog.NONE.id 46 | 47 | 48 | def test_potion_base_serde_empty_discrete(): 49 | ser = PotionBase.serialize_empty(discrete=True) 50 | de = PotionBase.deserialize(ser) 51 | 52 | assert de.id == PotionCatalog.NONE.id 53 | 54 | 55 | def create_potion_subclass(Model: Union[type[Potion], type[ShopPotion]]): 56 | @st.composite 57 | def create_subclass(draw, potion_bases=create_potion_base()): 58 | potion_base = draw(potion_bases) 59 | 60 | return draw( 61 | st.builds( 62 | Model, 63 | id=st.just(potion_base.id), 64 | name=st.just(potion_base.name), 65 | requires_target=st.just(potion_base.requires_target), 66 | ) 67 | ) 68 | 69 | return create_subclass() 70 | 71 | 72 | @given(create_potion_subclass(Potion)) 73 | def test_potion_serde(potion: Potion): 74 | ser = potion.serialize() 75 | de = potion.deserialize(ser) 76 | 77 | assert potion == de 78 | 79 | 80 | def test_potion_serde_empty(): 81 | ser = Potion.serialize_empty() 82 | de = Potion.deserialize(ser) 83 | 84 | assert de.id == PotionCatalog.NONE.id 85 | 86 | 87 | @given(create_potion_subclass(ShopPotion)) 88 | def test_shop_potion_serde(potion: ShopPotion): 89 | ser = potion.serialize() 90 | de = potion.deserialize(ser) 91 | 92 | assert potion == de 93 | 94 | 95 | def test_shop_potion_serde_empty(): 96 | ser = ShopPotion.serialize_empty() 97 | de = ShopPotion.deserialize(ser) 98 | 99 | assert de.id == PotionCatalog.NONE.id 100 | assert de.price == 0 101 | -------------------------------------------------------------------------------- /gym_sts/envs/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing as tp 3 | from typing import Optional 4 | 5 | from gym_sts.spaces.observations import Observation 6 | 7 | 8 | class SeedHelpers: 9 | char_set = "0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ" # Note no O 10 
| 11 | @classmethod 12 | def make_seed_str(cls, seed_long: int) -> str: 13 | """ 14 | Based on code from com/megacrit/cardcrawl/helpers/SeedHelper.java 15 | """ 16 | 17 | base = len(cls.char_set) 18 | 19 | seed_str = "" 20 | while seed_long != 0: 21 | seed_long, remainder = divmod(seed_long, base) 22 | char = cls.char_set[remainder] 23 | seed_str = char + seed_str 24 | 25 | return seed_str 26 | 27 | @classmethod 28 | def make_seed(cls, rng: random.Random) -> str: 29 | unsigned_long = 2**64 30 | seed_long = rng.randrange(unsigned_long) 31 | 32 | return cls.make_seed_str(seed_long) 33 | 34 | @classmethod 35 | def validate_seed(cls, seed: str) -> str: 36 | """ 37 | Returns the seed if it's valid, raises a ValueError otherwise. 38 | """ 39 | 40 | seed = seed.upper() 41 | for char in seed: 42 | if char not in cls.char_set: 43 | raise ValueError(f"Seed contains illegal character '{char}'") 44 | 45 | return seed 46 | 47 | 48 | T = tp.TypeVar("T") 49 | 50 | 51 | class Cache(tp.Generic[T]): 52 | def __init__(self, size: int = 10): 53 | self.size = size 54 | self.index = 0 55 | self.cache: list[Optional[T]] = [None] * self.size 56 | 57 | def append(self, obs: T): 58 | self.cache[self.index] = obs 59 | self.index = (self.index + 1) % self.size 60 | 61 | def get(self, ago: int = 0) -> Optional[T]: 62 | """ 63 | Args: 64 | ago: The number of items back to retrieve (zero indexed). 65 | The value must be less than the cache size. 
66 | """ 67 | 68 | if ago >= self.size: 69 | raise ValueError(f"ago must be less than the cache size ({self.size})") 70 | 71 | index = (self.index - ago - 1) % self.size 72 | return self.cache[index] 73 | 74 | def reset(self) -> None: 75 | self.cache = [None] * self.size 76 | 77 | 78 | def single_combat_value(obs: Observation) -> float: 79 | max_hp = sum(e.max_hp for e in obs.combat_state.enemies) 80 | enemy_hp = sum(e.current_hp for e in obs.combat_state.enemies) 81 | 82 | self_hp = obs.persistent_state.hp 83 | self_max_hp = obs.persistent_state.max_hp 84 | 85 | if max_hp == 0: 86 | enemy_hp = 0 87 | max_hp = 1 88 | 89 | p_damage = (max_hp - enemy_hp) / max_hp 90 | p_hp = self_hp / self_max_hp 91 | 92 | return (p_hp + 0.01) * p_damage 93 | 94 | 95 | def full_game_obs_value(obs: Observation) -> float: 96 | total = float(obs.persistent_state.floor) 97 | if obs.in_combat: 98 | total += single_combat_value(obs) 99 | return total 100 | -------------------------------------------------------------------------------- /gym_sts/spaces/actions.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | from gymnasium.spaces import Discrete 5 | from pydantic import BaseModel, PrivateAttr 6 | 7 | import gym_sts.spaces.constants.base as base_consts 8 | from gym_sts.spaces.constants import combat as combat_consts 9 | from gym_sts.spaces.constants import potions as potion_consts 10 | 11 | 12 | class Action(BaseModel, ABC): 13 | _id: int = PrivateAttr(-1) 14 | 15 | @abstractmethod 16 | def to_command(self) -> str: 17 | raise RuntimeError("not implemented") 18 | 19 | class Config: 20 | # Allows model instances to be hashable, e.g. they can be added to sets. 
21 | # See https://pydantic-docs.helpmanual.io/usage/model_config/ 22 | frozen = True 23 | 24 | 25 | class PickCard(Action): 26 | card_id: int 27 | upgraded: bool 28 | 29 | def to_command(self): 30 | # TODO: Finish this 31 | raise RuntimeError("not implemented") 32 | 33 | 34 | class PlayCard(Action): 35 | # NOTE: Card position starts with 1 36 | card_position: int 37 | target_index: Optional[int] 38 | 39 | def to_command(self): 40 | target = "" if self.target_index is None else self.target_index 41 | return f"PLAY {self.card_position} {target}" 42 | 43 | 44 | class PotionAction(Action): 45 | potion_index: int 46 | 47 | 48 | class UsePotion(PotionAction): 49 | target_index: Optional[int] = None 50 | 51 | def to_command(self): 52 | target = "" if self.target_index is None else self.target_index 53 | return f"POTION USE {self.potion_index} {target}" 54 | 55 | 56 | class DiscardPotion(PotionAction): 57 | def to_command(self): 58 | return f"POTION DISCARD {self.potion_index}" 59 | 60 | 61 | class Choose(Action): 62 | choice_index: int 63 | 64 | def to_command(self): 65 | return f"CHOOSE {self.choice_index}" 66 | 67 | 68 | class EndTurn(Action): 69 | def to_command(self): 70 | return "END" 71 | 72 | 73 | class Return(Action): 74 | def to_command(self): 75 | return "RETURN" 76 | 77 | 78 | class Proceed(Action): 79 | def to_command(self): 80 | return "PROCEED" 81 | 82 | 83 | def all_actions() -> list[Action]: 84 | actions = [EndTurn(), Return(), Proceed()] 85 | 86 | for i in range(base_consts.NUM_CHOICES): 87 | actions.append(Choose(choice_index=i)) 88 | 89 | for i in range(potion_consts.NUM_POTION_SLOTS): 90 | actions.append(UsePotion(potion_index=i)) 91 | actions.append(DiscardPotion(potion_index=i)) 92 | for j in range(combat_consts.MAX_NUM_ENEMIES): 93 | actions.append(UsePotion(potion_index=i, target_index=j)) 94 | 95 | for i in range(1, combat_consts.MAX_HAND_SIZE + 1): 96 | actions.append(PlayCard(card_position=i)) 97 | for j in range(combat_consts.MAX_NUM_ENEMIES): 
98 | actions.append(PlayCard(card_position=i, target_index=j)) 99 | 100 | for i, action in enumerate(actions): 101 | action._id = i 102 | 103 | return actions 104 | 105 | 106 | ACTIONS = all_actions() 107 | ACTION_SPACE = Discrete(len(ACTIONS)) 108 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/combat_reward.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | from gymnasium.spaces import Tuple 6 | 7 | import gym_sts.spaces.constants.rewards as reward_consts 8 | from gym_sts.spaces.observations import types 9 | 10 | from .base import ObsComponent 11 | 12 | 13 | class CombatRewardObs(ObsComponent): 14 | def __init__(self, game_state: dict): 15 | # Sane defaults 16 | self.rewards: list[types.Reward] = [] 17 | 18 | screen_type = game_state.get("screen_type") 19 | if screen_type is None: 20 | return 21 | 22 | screen_state = game_state["screen_state"] 23 | if screen_type == "COMBAT_REWARD": 24 | self.rewards = [ 25 | self._parse_reward(reward) for reward in screen_state["rewards"] 26 | ] 27 | elif screen_type == "BOSS_REWARD": 28 | self.rewards = [ 29 | types.RelicReward(value=types.RelicBase(**relic)) 30 | for relic in screen_state["relics"] 31 | ] 32 | 33 | @staticmethod 34 | def space(): 35 | return Tuple([types.Reward.space()] * reward_consts.MAX_NUM_REWARDS) 36 | 37 | @staticmethod 38 | def _parse_reward(reward: dict): 39 | reward_type = reward["reward_type"] 40 | 41 | if reward_type in ["GOLD", "STOLEN_GOLD"]: 42 | return types.GoldReward(value=reward["gold"]) 43 | elif reward_type == "POTION": 44 | potion = types.PotionBase(**reward["potion"]) 45 | return types.PotionReward(value=potion) 46 | elif reward_type == "RELIC": 47 | relic = types.RelicBase(**reward["relic"]) 48 | return types.RelicReward(value=relic) 49 | elif reward_type == "CARD": 50 | return types.CardReward() 51 | 
elif reward_type in ["EMERALD_KEY", "SAPPHIRE_KEY"]: 52 | # TODO is it important to encode the "link" info for the sapphire key? 53 | key_type = reward_type.split("_")[0] 54 | return types.KeyReward(value=key_type) 55 | else: 56 | raise ValueError(f"Unrecognized reward type {reward_type}") 57 | 58 | def serialize(self) -> list[dict]: 59 | serialized = [types.Reward.serialize_empty()] * reward_consts.MAX_NUM_REWARDS 60 | for i, reward in enumerate(self.rewards): 61 | serialized[i] = reward.serialize() 62 | return serialized 63 | 64 | SerializedState = list[types.Reward.SerializedState] 65 | 66 | @classmethod 67 | def deserialize(cls, data: Union[list[dict], SerializedState]) -> CombatRewardObs: 68 | rewards = [] 69 | for r in data: 70 | try: 71 | reward = types.Reward.deserialize(r) 72 | except types.Reward.NotDeserializable: 73 | continue 74 | rewards.append(reward) 75 | 76 | instance = cls({}) 77 | instance.rewards = rewards 78 | 79 | return instance 80 | 81 | def __eq__(self, other: object) -> bool: 82 | if not isinstance(other, CombatRewardObs): 83 | return False 84 | 85 | attrs = ["rewards"] 86 | 87 | for attr in attrs: 88 | if getattr(self, attr) != getattr(other, attr): 89 | return False 90 | 91 | return True 92 | -------------------------------------------------------------------------------- /gym_sts/spaces/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Dict, List, Tuple 4 | 5 | from gym_sts.constants import PROJECT_ROOT 6 | from gym_sts.spaces.constants.events import ( 7 | ALL_EVENTS, 8 | GLOBALLY_CHECKED_TEXTS, 9 | MAX_NUM_CUSTOM_TEXTS, 10 | ) 11 | 12 | 13 | REMOVED_STRINGS = ["#r", "#y", "#g", "#b", "#p", "@", "~", "NL"] 14 | 15 | 16 | # Utility to match event text with normal text 17 | class EventData: 18 | def __init__(self, file_path: Path): 19 | with open(file_path, "r") as f: 20 | self._data = json.load(f) 21 | self.event_texts: Dict[str, 
List[str]] = { 22 | # Hardcoded case for Neow Event as it is not in the events.json but instead 23 | # in characters.json 24 | # TODO: Show the individual choices for Neow Event 25 | "Neow Event": [] 26 | } 27 | for event_id, event in self._data.items(): 28 | if "OPTIONS" not in event: 29 | # One of the event texts is malformed; namely, "Proceed Screen" which 30 | # is not mentioned anywhere 31 | continue 32 | results = [] 33 | for desc in event["DESCRIPTIONS"]: 34 | results.append(self.remove_formatting(desc)) 35 | for opt in event["OPTIONS"]: 36 | results.append(self.remove_formatting(opt)) 37 | 38 | # NOTE: The game uses the event_id keys, internally, NOT the names 39 | self.event_texts[event_id] = results 40 | 41 | self.sanity_check() 42 | 43 | def remove_formatting(self, text: str): 44 | for s in REMOVED_STRINGS: 45 | text = text.replace(s, "") 46 | return text 47 | 48 | def find_matches(self, event_id: str, text: str) -> List[Tuple[str, bool]]: 49 | # Returns (str, bool) pairs for matching texts 50 | # The order in which the elements are output should be deterministic 51 | match_list = [] 52 | 53 | for part in GLOBALLY_CHECKED_TEXTS: 54 | match_list.append((part, part in text)) 55 | 56 | for part in self.event_texts[event_id]: 57 | match_list.append((part, part in text)) 58 | 59 | return match_list 60 | 61 | def sanity_check(self): 62 | # Simple sanity check to check if we actually read in the data 63 | assert len(self.event_texts.keys()) > 0 64 | for event in ALL_EVENTS: 65 | if event != "NONE": 66 | if event not in self.event_texts: 67 | raise RuntimeError(f"Event {event} not found in event IDs") 68 | 69 | texts = self.event_texts[event] 70 | 71 | # TODO: Cleanup after you figure out how to handle Neow Event 72 | if len(texts) == 0 and event != "Neow Event": 73 | raise RuntimeError( 74 | f"Event {event} seemingly has no text associated with it" 75 | ) 76 | 77 | if len(texts) > MAX_NUM_CUSTOM_TEXTS: 78 | raise RuntimeError( 79 | f"Event {event} has too many 
custom texts! " 80 | f"({len(texts)}/{MAX_NUM_CUSTOM_TEXTS})" 81 | ) 82 | 83 | 84 | EVENTS_JSON_PATH = PROJECT_ROOT / "data" / "events.json" 85 | EVENT_DATA = EventData(EVENTS_JSON_PATH) 86 | -------------------------------------------------------------------------------- /tests/observations/serde/types/test_cards.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from hypothesis import given 4 | from hypothesis import strategies as st 5 | 6 | from gym_sts.spaces.constants.cards import CardCatalog, CardMetadata 7 | from gym_sts.spaces.observations.types import Card, HandCard, ShopCard 8 | 9 | 10 | @st.composite 11 | def create_card( 12 | draw, card_ids=st.sampled_from(CardCatalog.ids), upgrades=st.booleans() 13 | ): 14 | card_id = draw(card_ids) 15 | upgraded = draw(upgrades) 16 | card_metadata: CardMetadata = getattr(CardCatalog, card_id) 17 | card_props = card_metadata.upgraded if upgraded else card_metadata.unupgraded 18 | 19 | return draw( 20 | st.builds( 21 | Card, 22 | id=st.just(card_id), 23 | name=st.just(card_metadata.name), 24 | cost=st.just(card_props.default_cost), 25 | exhausts=st.just(card_props.exhausts), 26 | ethereal=st.just(card_props.ethereal), 27 | has_target=st.just(card_props.has_target), 28 | upgrades=st.just(int(upgraded)), 29 | ) 30 | ) 31 | 32 | 33 | @given(create_card()) 34 | def test_card_serde_binary(card: Card): 35 | ser = card.serialize() 36 | de = card.deserialize(ser) 37 | 38 | assert card == de 39 | 40 | 41 | @given(create_card()) 42 | def test_card_serde_discrete(card: Card): 43 | ser = card.serialize(discrete=True) 44 | de = card.deserialize(ser) 45 | 46 | assert card == de 47 | 48 | 49 | def test_card_serde_empty_binary(): 50 | ser = Card.serialize_empty() 51 | de = Card.deserialize(ser) 52 | 53 | assert de.id == CardCatalog.NONE.id 54 | 55 | 56 | def test_card_serde_empty_discrete(): 57 | ser = Card.serialize_empty(discrete=True) 58 | de =
Card.deserialize(ser) 59 | 60 | assert de.id == CardCatalog.NONE.id 61 | 62 | 63 | def create_card_subclass(Model: Union[type[HandCard], type[ShopCard]]): 64 | @st.composite 65 | def create_subclass(draw, cards=create_card()): 66 | card = draw(cards) 67 | 68 | return draw( 69 | st.builds( 70 | Model, 71 | id=st.just(card.id), 72 | name=st.just(card.name), 73 | cost=st.just(card.cost), 74 | exhausts=st.just(card.exhausts), 75 | ethereal=st.just(card.ethereal), 76 | has_target=st.just(card.has_target), 77 | upgrades=st.just(card.upgrades), 78 | ) 79 | ) 80 | 81 | return create_subclass() 82 | 83 | 84 | @given(create_card_subclass(HandCard)) 85 | def test_hand_card_serde(card: HandCard): 86 | ser = card.serialize() 87 | de = card.deserialize(ser) 88 | 89 | assert card == de 90 | 91 | 92 | def test_hand_card_serde_empty(): 93 | ser = HandCard.serialize_empty() 94 | de = HandCard.deserialize(ser) 95 | 96 | assert de.id == CardCatalog.NONE.id 97 | 98 | 99 | @given(create_card_subclass(ShopCard)) 100 | def test_shop_card_serde(card: ShopCard): 101 | ser = card.serialize() 102 | de = card.deserialize(ser) 103 | 104 | assert card == de 105 | 106 | 107 | def test_shop_card_serde_empty(): 108 | ser = ShopCard.serialize_empty() 109 | de = ShopCard.deserialize(ser) 110 | 111 | assert de.id == CardCatalog.NONE.id 112 | assert de.price == 0 113 | -------------------------------------------------------------------------------- /tests/observations/test_enemy_serialization.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | import gym_sts.spaces.constants.combat as combat_consts 6 | from gym_sts.spaces import actions 7 | from gym_sts.spaces.observations import utils 8 | 9 | 10 | def test_enemy_attack_serialization(env): 11 | env.reset(seed=42) 12 | 13 | # First, demonstrate that damage is 0 when the enemy's intent isn't an attack 14 | env.communicator.basemod("fight Hexaghost") 15 | 16 | #
Hexaghost's intent on turn 1 is always "unknown" 17 | obs = env.observe(add_to_cache=True) 18 | serialization = obs.combat_state.serialize() 19 | enemies = serialization["enemies"] 20 | assert len(enemies) > 0 21 | hexaghost = enemies[0] 22 | expected_damage = utils.to_binary_array(0, combat_consts.LOG_MAX_ATTACK) 23 | assert np.array_equal(hexaghost["attack"]["damage"], expected_damage) 24 | 25 | # Next, demonstrate that damage and number of hits are correct 26 | # Hexaghost's attack on turn 2 hits 6 times for (player HP // 12) + 1 damage each, so at 1 HP we expect 6 hits of 1 damage 27 | hp = obs.persistent_state.hp 28 | env.communicator.basemod(f"hp lose {hp - 1}") 29 | action_id = actions.ACTIONS.index(actions.EndTurn()) 30 | _, _, _, info = env.step(action_id) 31 | obs = info["observation"] 32 | serialization = obs.combat_state.serialize() 33 | enemies = serialization["enemies"] 34 | assert len(enemies) > 0 35 | hexaghost = enemies[0] 36 | expected_damage = utils.to_binary_array(1, combat_consts.LOG_MAX_ATTACK) 37 | expected_times = utils.to_binary_array(6, combat_consts.LOG_MAX_ATTACK_TIMES) 38 | assert np.array_equal(hexaghost["attack"]["damage"], expected_damage) 39 | assert np.array_equal(hexaghost["attack"]["times"], expected_times) 40 | 41 | # Next, demonstrate that damage may be adjusted from base values 42 | env.communicator.basemod("fight Gremlin_Nob") 43 | env.communicator.basemod("deck remove all") 44 | env.communicator.basemod("deck add Discovery") 45 | time.sleep(0.5) # Wait briefly for card adding animation to complete 46 | obs = env.observe(add_to_cache=True) 47 | _, _, _, info = env.step(action_id) 48 | obs = info["observation"] 49 | serialization = obs.combat_state.serialize() 50 | enemies = serialization["enemies"] 51 | assert len(enemies) > 0 52 | gremlin_nob = enemies[0] 53 | damage = utils.from_binary_array(gremlin_nob["attack"]["damage"]) 54 | 55 | # At this point the only card in hand is Discovery, a Skill; playing it triggers Gremlin Nob's Enrage (+2 Strength) 56 | # TODO make action selection easier 57 | action_id = actions.ACTIONS.index( 58 |
actions.PlayCard(card_position=1, target_index=None) 59 | ) 60 | _, _, _, info = env.step(action_id) 61 | obs = info["observation"] 62 | serialization = obs.combat_state.serialize() 63 | enemies = serialization["enemies"] 64 | assert len(enemies) > 0 65 | gremlin_nob = enemies[0] 66 | new_damage = gremlin_nob["attack"]["damage"] 67 | expected_damage = utils.to_binary_array(damage + 2, combat_consts.LOG_MAX_ATTACK) 68 | assert np.array_equal(new_damage, expected_damage) 69 | 70 | # Runic Dome hides enemy intents; serialization must not error and should report 0 damage and 0 hits 71 | env.communicator.basemod("relic add Runic_Dome") 72 | env.communicator.basemod("fight Looter") 73 | obs = env.observe(add_to_cache=True) 74 | serialization = obs.combat_state.serialize() 75 | enemies = serialization["enemies"] 76 | assert len(enemies) > 0 77 | looter = enemies[0] 78 | damage = utils.from_binary_array(looter["attack"]["damage"]) 79 | times = utils.from_binary_array(looter["attack"]["times"]) 80 | assert damage == 0 81 | assert times == 0 82 | -------------------------------------------------------------------------------- /gym_sts/test_valid_actions.py: -------------------------------------------------------------------------------- 1 | """Tests that an action leads to an error iff it is invalid.
2 | 3 | Best to run with ipdb to post-mortem debug errors: 4 | 5 | python -m ipdb -c c gym_sts/test_valid_actions.py 6 | """ 7 | 8 | import argparse 9 | import os 10 | import pickle 11 | import random 12 | import time 13 | 14 | import numpy as np 15 | import tree 16 | 17 | from gym_sts.envs.action_validation import validate 18 | from gym_sts.envs.base import SlayTheSpireGymEnv 19 | from gym_sts.spaces.actions import ACTIONS 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("lib_dir") 25 | parser.add_argument("mods_dir") 26 | parser.add_argument("out_dir") 27 | parser.add_argument("--build-image", action="store_true") 28 | parser.add_argument("--headless", action="store_true") 29 | parser.add_argument("--render", action="store_true") 30 | parser.add_argument("--runtime", default=30, type=int) 31 | parser.add_argument("--allow_invalid", action="store_true") 32 | parser.add_argument("--screenshots", action="store_true") 33 | parser.add_argument("--dump_states", action="store_true") 34 | 35 | args = parser.parse_args() 36 | 37 | if args.build_image: 38 | SlayTheSpireGymEnv.build_image() 39 | env = SlayTheSpireGymEnv( 40 | args.lib_dir, 41 | args.mods_dir, 42 | args.out_dir, 43 | headless=args.headless, 44 | animate=args.render, 45 | reboot_on_error=True, 46 | ) 47 | env.reset(seed=42) 48 | rng = random.Random(42) 49 | 50 | valid_choices = [True] 51 | if args.allow_invalid: 52 | valid_choices.append(False) 53 | 54 | num_steps = 0 55 | start_time = time.perf_counter() 56 | 57 | if args.dump_states: 58 | states = [] 59 | 60 | while True: 61 | if args.screenshots: 62 | env.screenshot(f"frame{num_steps:03d}.png") 63 | 64 | last_obs = env.observation_cache.get() 65 | assert last_obs is not None 66 | print(last_obs.screen_type) 67 | 68 | want_valid = rng.choice(valid_choices) 69 | 70 | actions = [ 71 | action for action in ACTIONS if validate(action, last_obs) == want_valid 72 | ] 73 | 74 | if len(actions) == 0: 75 | raise 
ValueError("No %svalid actions!" % ("" if want_valid else "in")) 76 | 77 | action = rng.choice(actions) 78 | print(repr(action)) 79 | try: 80 | serialized, _, done, info = env.step(action._id) 81 | except TimeoutError as e: 82 | run_time = time.perf_counter() - start_time 83 | print(f"Error on step {num_steps} after {run_time} seconds.") 84 | print(e) 85 | if args.headless: 86 | env.screenshot("error.png") 87 | raise e 88 | 89 | assert info["had_error"] != want_valid 90 | 91 | if args.dump_states: 92 | states.append(serialized) 93 | # assert env.observation_space.contains(serialized) 94 | 95 | if done: 96 | print("RESET") 97 | env.reset() 98 | 99 | num_steps += 1 100 | run_time = time.perf_counter() - start_time 101 | if run_time >= args.runtime: 102 | break 103 | 104 | fps = num_steps / run_time 105 | print(f"fps: {fps}") 106 | 107 | if args.dump_states: 108 | states_file = os.path.join(args.out_dir, "states.pkl") 109 | with open(states_file, "wb") as f: 110 | column_major = tree.map_structure(lambda *xs: np.array(xs), *states) 111 | pickle.dump(column_major, f) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/types/map.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | from gymnasium.spaces import Dict, Discrete, MultiBinary, MultiDiscrete 7 | from pydantic import BaseModel 8 | 9 | import gym_sts.spaces.constants.map as map_consts 10 | 11 | from .base import BinaryArray 12 | 13 | 14 | class MapCoordinates(BaseModel): 15 | x: int 16 | y: int 17 | 18 | 19 | class StandardNode(MapCoordinates): 20 | symbol: str 21 | children: list[MapCoordinates] 22 | 23 | 24 | class EliteNode(StandardNode): 25 | is_burning: bool 26 | 27 | 28 | Node = Union[StandardNode, EliteNode] 29 | 30 | 31 | class Map(BaseModel): 32 | 
nodes: list[Node] = [] 33 | boss: str = "NONE" # TODO use enum 34 | 35 | @staticmethod 36 | def space() -> Dict: 37 | return Dict( 38 | { 39 | "nodes": MultiDiscrete( 40 | [map_consts.NUM_MAP_LOCATIONS] * map_consts.NUM_MAP_NODES 41 | ), 42 | "edges": MultiBinary(map_consts.NUM_MAP_EDGES), 43 | "boss": Discrete(map_consts.NUM_NORMAL_BOSSES), 44 | } 45 | ) 46 | 47 | def serialize(self) -> dict: 48 | empty_node = map_consts.ALL_MAP_LOCATIONS.index("NONE") 49 | _nodes = np.full([map_consts.NUM_MAP_NODES], empty_node, dtype=np.uint8) 50 | edges = np.zeros([map_consts.NUM_MAP_EDGES], dtype=bool) 51 | 52 | for node in self.nodes: 53 | x, y = node.x, node.y 54 | node_index = map_consts.NUM_MAP_NODES_PER_ROW * y + x 55 | symbol = node.symbol 56 | 57 | if symbol == "E": 58 | # Depends on json field added in our CommunicationMod fork 59 | if isinstance(node, EliteNode) and node.is_burning: 60 | symbol = "B" 61 | 62 | node_type = map_consts.ALL_MAP_LOCATIONS.index(symbol) 63 | _nodes[node_index] = node_type 64 | 65 | if y < map_consts.NUM_MAP_ROWS - 1: 66 | edge_index = node_index * map_consts.NUM_MAP_EDGES_PER_NODE 67 | 68 | child_x_coords = [child.x for child in node.children] 69 | 70 | for coord in [x - 1, x, x + 1]: 71 | if coord in child_x_coords: 72 | edges[edge_index] = True 73 | edge_index += 1 74 | 75 | _boss = map_consts.NORMAL_BOSSES.index(self.boss) 76 | return { 77 | "nodes": _nodes, 78 | "edges": edges, 79 | "boss": _boss, 80 | } 81 | 82 | class SerializedState(BaseModel): 83 | nodes: BinaryArray 84 | edges: BinaryArray 85 | boss: int 86 | 87 | class Config: 88 | arbitrary_types_allowed = True 89 | 90 | @classmethod 91 | def deserialize(cls, data: SerializedState) -> Map: 92 | nodes = [] 93 | for pos, node in enumerate(data.nodes): 94 | node_type = map_consts.ALL_MAP_LOCATIONS[node] 95 | 96 | if node_type == "NONE": 97 | continue 98 | 99 | y, x = divmod(pos, map_consts.NUM_MAP_NODES_PER_ROW) 100 | children = [] 101 | 102 | if y < map_consts.NUM_MAP_ROWS - 1: 103 
| edge_index = ( 104 | map_consts.NUM_MAP_NODES_PER_ROW * y + x 105 | ) * map_consts.NUM_MAP_EDGES_PER_NODE 106 | 107 | for i, coord in enumerate([x - 1, x, x + 1]): 108 | if data.edges[edge_index + i]: 109 | children.append({"x": coord, "y": y + 1}) 110 | else: 111 | children.append({"x": 3, "y": y + 2}) 112 | 113 | nodes.append({"symbol": node_type, "children": children, "x": x, "y": y}) 114 | 115 | boss = map_consts.NORMAL_BOSSES[data.boss] 116 | 117 | return cls(nodes=nodes, boss=boss) 118 | -------------------------------------------------------------------------------- /gym_sts/build/preferences/STSSeenRelics: -------------------------------------------------------------------------------- 1 | { 2 | "Burning Blood": "1", 3 | "Ring of the Snake": "1", 4 | "Cracked Core": "1", 5 | "PureWater": "1", 6 | "Boot": "1", 7 | "Pocketwatch": "1", 8 | "Sling": "1", 9 | "Dream Catcher": "1", 10 | "Odd Mushroom": "1", 11 | "Potion Belt": "1", 12 | "Self Forming Clay": "1", 13 | "Cauldron": "1", 14 | "Philosopher\u0027s Stone": "1", 15 | "Velvet Choker": "1", 16 | "Mark of Pain": "1", 17 | "Oddly Smooth Stone": "1", 18 | "White Beast Statue": "1", 19 | "Pear": "1", 20 | "Pantograph": "1", 21 | "ClockworkSouvenir": "1", 22 | "Bronze Scales": "1", 23 | "Vajra": "1", 24 | "Membership Card": "1", 25 | "Coffee Dripper": "1", 26 | "Runic Dome": "1", 27 | "Astrolabe": "1", 28 | "Tingsha": "1", 29 | "Blood Vial": "1", 30 | "TheAbacus": "1", 31 | "Pen Nib": "1", 32 | "MutagenicStrength": "1", 33 | "MawBank": "1", 34 | "Regal Pillow": "1", 35 | "FossilizedHelix": "1", 36 | "Toolbox": "1", 37 | "Incense Burner": "1", 38 | "Ornamental Fan": "1", 39 | "Orichalcum": "1", 40 | "Snecko Eye": "1", 41 | "Sozu": "1", 42 | "Cursed Key": "1", 43 | "Bag of Marbles": "1", 44 | "Meat on the Bone": "1", 45 | "OrangePellets": "1", 46 | "CultistMask": "1", 47 | "HornCleat": "1", 48 | "Shuriken": "1", 49 | "Thread and Needle": "1", 50 | "Strange Spoon": "1", 51 | "Torii": "1", 52 | "Paper Crane": 
"1", 53 | "Centennial Puzzle": "1", 54 | "Whetstone": "1", 55 | "InkBottle": "1", 56 | "Matryoshka": "1", 57 | "TwistedFunnel": "1", 58 | "Mummified Hand": "1", 59 | "War Paint": "1", 60 | "Medical Kit": "1", 61 | "Ectoplasm": "1", 62 | "Ice Cream": "1", 63 | "Happy Flower": "1", 64 | "Fusion Hammer": "1", 65 | "Bird Faced Urn": "1", 66 | "Tough Bandages": "1", 67 | "PrismaticShard": "1", 68 | "Peace Pipe": "1", 69 | "Golden Idol": "1", 70 | "Bottled Lightning": "1", 71 | "Bottled Tornado": "1", 72 | "Old Coin": "1", 73 | "Lantern": "1", 74 | "Runic Cube": "1", 75 | "Tiny House": "1", 76 | "MealTicket": "1", 77 | "Nunchaku": "1", 78 | "Juzu Bracelet": "1", 79 | "Unceasing Top": "1", 80 | "Question Card": "1", 81 | "Brimstone": "1", 82 | "Empty Cage": "1", 83 | "Girya": "1", 84 | "Letter Opener": "1", 85 | "Gremlin Horn": "1", 86 | "Necronomicon": "1", 87 | "Snake Skull": "1", 88 | "Mercury Hourglass": "1", 89 | "Strawberry": "1", 90 | "Sundial": "1", 91 | "Orrery": "1", 92 | "Kunai": "1", 93 | "Du-Vu Doll": "1", 94 | "Smiling Mask": "1", 95 | "Tiny Chest": "1", 96 | "Black Star": "1", 97 | "Ninja Scroll": "1", 98 | "Calipers": "1", 99 | "CaptainsWheel": "1", 100 | "Runic Pyramid": "1", 101 | "DollysMirror": "1", 102 | "Toy Ornithopter": "1", 103 | "Darkstone Periapt": "1", 104 | "Chemical X": "1", 105 | "Anchor": "1", 106 | "SlaversCollar": "1", 107 | "Calling Bell": "1", 108 | "StoneCalendar": "1", 109 | "Bag of Preparation": "1", 110 | "HandDrill": "1", 111 | "Red Mask": "1", 112 | "Toxic Egg 2": "1", 113 | "TungstenRod": "1", 114 | "Busted Crown": "1", 115 | "Lizard Tail": "1", 116 | "Damaru": "1", 117 | "Gambling Chip": "1", 118 | "Eternal Feather": "1", 119 | "SacredBark": "1", 120 | "NeowsBlessing": "1", 121 | "WristBlade": "1", 122 | "StrikeDummy": "1", 123 | "Frozen Egg 2": "1", 124 | "PreservedInsect": "1", 125 | "Lee\u0027s Waffle": "1", 126 | "Charon\u0027s Ashes": "1", 127 | "Omamori": "1", 128 | "Prayer Wheel": "1", 129 | "Shovel": "1", 130 | 
"Inserter": "1", 131 | "Frozen Eye": "1", 132 | "Bottled Flame": "1", 133 | "Pandora\u0027s Box": "1", 134 | "Blue Candle": "1", 135 | "Dead Branch": "1", 136 | "Singing Bowl": "1", 137 | "Enchiridion": "1", 138 | "VioletLotus": "1", 139 | "Ancient Tea Set": "1", 140 | "Magic Flower": "1", 141 | "Paper Frog": "1", 142 | "Black Blood": "1", 143 | "Art of War": "1", 144 | "Molten Egg 2": "1", 145 | "The Courier": "1", 146 | "Ring of the Serpent": "1", 147 | "Mango": "1", 148 | "Ginger": "1", 149 | "Red Skull": "1", 150 | "Nloth\u0027s Gift": "1", 151 | "HoveringKite": "1", 152 | "The Specimen": "1" 153 | } 154 | -------------------------------------------------------------------------------- /gym_sts/communication/communicator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | 5 | from gym_sts.communication.receiver import Receiver 6 | from gym_sts.communication.sender import Sender 7 | from gym_sts.spaces.observations import Observation 8 | 9 | 10 | def init_fifos(filenames): 11 | # Create fifos for communication 12 | for f in filenames: 13 | if os.path.exists(f): 14 | os.remove(f) 15 | os.mkfifo(f) 16 | 17 | 18 | class Communicator: 19 | def __init__(self, input_path: Path, output_path: Path): 20 | self.input_path = input_path 21 | self.output_path = output_path 22 | init_fifos([self.input_path, self.output_path]) 23 | self.receiver = Receiver(self.output_path) 24 | self.sender = Sender(self.input_path) 25 | 26 | def _manual_command(self, action: str) -> Observation: 27 | self.receiver.empty_fifo() 28 | self.sender._send_message(action) 29 | state = self.receiver.receive_game_state() 30 | return Observation(state) 31 | 32 | def ready(self) -> None: 33 | self.sender.send_ready() 34 | 35 | def choose(self, choice) -> Observation: 36 | self.receiver.empty_fifo() 37 | self.sender.send_choose(choice) 38 | state = self.receiver.receive_game_state() 39 | return 
Observation(state) 40 | 41 | def click(self, x: int, y: int, left: bool = True) -> Observation: 42 | self.receiver.empty_fifo() 43 | self.sender.send_click(x, y, left=left) 44 | state = self.receiver.receive_game_state() 45 | return Observation(state) 46 | 47 | def end(self) -> Observation: 48 | self.receiver.empty_fifo() 49 | self.sender.send_end() 50 | state = self.receiver.receive_game_state() 51 | return Observation(state) 52 | 53 | def potion(self, action, slot, target) -> Observation: 54 | self.receiver.empty_fifo() 55 | self.sender.send_potion(action, slot, target) 56 | state = self.receiver.receive_game_state() 57 | return Observation(state) 58 | 59 | def proceed(self) -> Observation: 60 | self.receiver.empty_fifo() 61 | self.sender.send_proceed() 62 | state = self.receiver.receive_game_state() 63 | return Observation(state) 64 | 65 | def resign(self) -> Observation: 66 | self.receiver.empty_fifo() 67 | self.sender.send_resign() 68 | state = self.receiver.receive_game_state() 69 | return Observation(state) 70 | 71 | def start(self, player_class: str, ascension: int, seed: str) -> Observation: 72 | self.receiver.empty_fifo() 73 | self.sender.send_start(player_class, ascension, seed) 74 | 75 | tries = 3 76 | for _ in range(tries): 77 | state = self.receiver.receive_game_state() 78 | if state["in_game"]: 79 | return Observation(state) 80 | 81 | time.sleep(0.05) 82 | 83 | raise TimeoutError("Waited for game to start, but it didn't happen.") 84 | 85 | def state(self) -> Observation: 86 | """ 87 | Get the JSON representation of the current game state, regardless of whether or 88 | not the game is "stable." This method is valid in all game states. 
89 | """ 90 | 91 | self.receiver.empty_fifo() 92 | self.sender.send_state() 93 | state = self.receiver.receive_game_state() 94 | return Observation(state) 95 | 96 | def wait(self, frames: int) -> Observation: 97 | self.receiver.empty_fifo() 98 | self.sender.send_wait(frames) 99 | state = self.receiver.receive_game_state() 100 | return Observation(state) 101 | 102 | def basemod(self, command: str) -> Observation: 103 | """ 104 | Send a command to the basemod console and return the subsequent game state. 105 | 106 | WARNING: The returned observation is not guaranteed to fully reflect the results 107 | of the basemod command. Some actions, like adding cards, queue an animation that 108 | CommunicationMod does not wait for before sending the next state. 109 | """ 110 | 111 | self.receiver.empty_fifo() 112 | self.sender.send_basemod(command) 113 | state = self.receiver.receive_game_state() 114 | return Observation(state) 115 | 116 | def render(self, render: bool) -> None: 117 | """ 118 | Toggle whether or not the game should render to the screen. 119 | 120 | Note that this command does not return a state response. 
121 | """ 122 | 123 | self.sender.send_render(render) 124 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/types/potions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | from gymnasium.spaces import Dict, Discrete, MultiBinary 7 | from pydantic import BaseModel 8 | 9 | import gym_sts.spaces.constants.potions as potion_consts 10 | import gym_sts.spaces.constants.shop as shop_consts 11 | from gym_sts.spaces.constants.potions import PotionCatalog 12 | from gym_sts.spaces.observations import utils 13 | 14 | from .base import BinaryArray, ShopMixin 15 | 16 | 17 | class PotionBase(BaseModel): 18 | id: str 19 | name: str 20 | requires_target: bool 21 | 22 | @classmethod 23 | def _serialize(cls, potion_id: str, discrete=False) -> Union[BinaryArray, int]: 24 | potion_idx = PotionCatalog.ids.index(potion_id) 25 | if discrete: 26 | return potion_idx 27 | else: 28 | return utils.to_binary_array(potion_idx, potion_consts.LOG_NUM_POTIONS) 29 | 30 | @classmethod 31 | def serialize_empty(cls, discrete=False) -> Union[BinaryArray, int]: 32 | return cls._serialize(PotionCatalog.NONE.id, discrete=discrete) 33 | 34 | def serialize(self, discrete=False) -> Union[BinaryArray, int]: 35 | return self._serialize(self.id, discrete=discrete) 36 | 37 | @classmethod 38 | def deserialize(cls, potion_idx: Union[int, BinaryArray]) -> PotionBase: 39 | # BinaryArray can't be used with isinstance(), so check for the underlying np.ndarray instead
40 | if isinstance(potion_idx, np.ndarray): 41 | potion_idx = utils.from_binary_array(potion_idx) 42 | 43 | potion_id = PotionCatalog.ids[potion_idx] 44 | potion_meta: potion_consts.PotionMetadata = getattr(PotionCatalog, potion_id) 45 | 46 | return cls( 47 | id=potion_id, 48 | name=potion_meta.name, 49 | requires_target=potion_meta.requires_target, 50 | ) 51 | 52 | 53 | class Potion(PotionBase): 54 | can_use: bool 55 | can_discard: bool 56 | 57 | @staticmethod 58 | def space() -> Dict: 59 | return Dict( 60 | { 61 | "id": MultiBinary(potion_consts.LOG_NUM_POTIONS), 62 | "can_use": Discrete(2), 63 | "can_discard": Discrete(2), 64 | } 65 | ) 66 | 67 | @classmethod 68 | def serialize_empty(cls) -> dict: # type: ignore[override] 69 | return { 70 | "id": super().serialize_empty(), 71 | "can_use": 0, 72 | "can_discard": 0, 73 | } 74 | 75 | def serialize(self) -> dict: # type: ignore[override] 76 | return { 77 | "id": super().serialize(), 78 | "can_use": int(self.can_use), 79 | "can_discard": int(self.can_discard), 80 | } 81 | 82 | class SerializedState(BaseModel): 83 | id: BinaryArray 84 | can_use: int 85 | can_discard: int 86 | 87 | class Config: 88 | arbitrary_types_allowed = True 89 | 90 | @classmethod 91 | def deserialize( # type: ignore[override] 92 | cls, data: Union[dict, SerializedState] 93 | ) -> Potion: 94 | if not isinstance(data, cls.SerializedState): 95 | data = cls.SerializedState(**data) 96 | 97 | potion_base = PotionBase.deserialize(data.id) 98 | can_use = bool(data.can_use) 99 | can_discard = bool(data.can_discard) 100 | 101 | return cls( 102 | id=potion_base.id, 103 | name=potion_base.name, 104 | requires_target=potion_base.requires_target, 105 | can_use=can_use, 106 | can_discard=can_discard, 107 | ) 108 | 109 | 110 | class ShopPotion(PotionBase, ShopMixin): 111 | @staticmethod 112 | def space() -> Dict: 113 | return Dict( 114 | { 115 | "potion": MultiBinary(potion_consts.LOG_NUM_POTIONS), 116 | "price": MultiBinary(shop_consts.SHOP_LOG_MAX_PRICE), 117 
| } 118 | ) 119 | 120 | @classmethod 121 | def serialize_empty(cls) -> dict: # type: ignore[override] 122 | return { 123 | "potion": super().serialize_empty(), 124 | "price": cls.serialize_empty_price(), 125 | } 126 | 127 | def serialize(self) -> dict: # type: ignore[override] 128 | return {"potion": super().serialize(), "price": self.serialize_price()} 129 | 130 | class SerializedState(BaseModel): 131 | potion: BinaryArray 132 | price: BinaryArray 133 | 134 | class Config: 135 | arbitrary_types_allowed = True 136 | 137 | @classmethod 138 | def deserialize( # type: ignore[override] 139 | cls, data: Union[dict, SerializedState] 140 | ) -> ShopPotion: 141 | if not isinstance(data, cls.SerializedState): 142 | data = cls.SerializedState(**data) 143 | 144 | potion = PotionBase.deserialize(data.potion) 145 | price = ShopMixin.deserialize_price(data.price) 146 | 147 | return cls(**potion.dict(), price=price) 148 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/shop.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | from gymnasium.spaces import Dict, Discrete, MultiBinary, Space, Tuple 6 | from pydantic import BaseModel 7 | 8 | import gym_sts.spaces.constants.cards as card_consts 9 | import gym_sts.spaces.constants.shop as shop_consts 10 | from gym_sts.spaces.constants.cards import CardCatalog 11 | from gym_sts.spaces.constants.potions import PotionCatalog 12 | from gym_sts.spaces.constants.relics import RelicCatalog 13 | from gym_sts.spaces.observations import types, utils 14 | 15 | from .base import PydanticComponent 16 | 17 | 18 | class SerializedPurge(BaseModel): 19 | available: int 20 | price: types.BinaryArray 21 | 22 | class Config: 23 | arbitrary_types_allowed = True 24 | 25 | 26 | class ShopObs(PydanticComponent): 27 | cards: list[types.ShopCard] = [] 28 | relics: 
list[types.ShopRelic] = [] 29 | potions: list[types.ShopPotion] = [] 30 | purge_available: bool = False 31 | purge_cost: int = 0 32 | 33 | @staticmethod 34 | def space() -> Space: 35 | return Dict( 36 | { 37 | "cards": Tuple( 38 | [ 39 | Dict( 40 | { 41 | "card": MultiBinary( 42 | card_consts.LOG_NUM_CARDS_WITH_UPGRADES 43 | ), 44 | "price": MultiBinary(shop_consts.SHOP_LOG_MAX_PRICE), 45 | } 46 | ) 47 | ] 48 | * shop_consts.SHOP_CARD_COUNT, 49 | ), 50 | "relics": Tuple( 51 | [types.ShopRelic.space()] * shop_consts.SHOP_RELIC_COUNT 52 | ), 53 | "potions": Tuple( 54 | [types.ShopPotion.space()] * shop_consts.SHOP_POTION_COUNT 55 | ), 56 | "purge": Dict( 57 | { 58 | "available": Discrete(2), 59 | "price": MultiBinary(shop_consts.SHOP_LOG_MAX_PRICE), 60 | } 61 | ), 62 | } 63 | ) 64 | 65 | def serialize(self) -> dict: 66 | serialized_cards = [ 67 | types.ShopCard.serialize_empty() 68 | ] * shop_consts.SHOP_CARD_COUNT 69 | for i, card in enumerate(self.cards): 70 | serialized_cards[i] = card.serialize() 71 | 72 | serialized_relics = [ 73 | types.ShopRelic.serialize_empty() 74 | ] * shop_consts.SHOP_RELIC_COUNT 75 | for i, relic in enumerate(self.relics): 76 | serialized_relics[i] = relic.serialize() 77 | 78 | serialized_potions = [ 79 | types.ShopPotion.serialize_empty() 80 | ] * shop_consts.SHOP_POTION_COUNT 81 | for i, potion in enumerate(self.potions): 82 | serialized_potions[i] = potion.serialize() 83 | 84 | serialized_purge = { 85 | "available": int(self.purge_available), 86 | "price": utils.to_binary_array( 87 | self.purge_cost, shop_consts.SHOP_LOG_MAX_PRICE 88 | ), 89 | } 90 | 91 | return { 92 | "cards": serialized_cards, 93 | "relics": serialized_relics, 94 | "potions": serialized_potions, 95 | "purge": serialized_purge, 96 | } 97 | 98 | class SerializedState(BaseModel): 99 | cards: list[types.ShopCard.SerializedState] 100 | relics: list[types.ShopRelic.SerializedState] 101 | potions: list[types.ShopPotion.SerializedState] 102 | purge: SerializedPurge 103 | 104 
| @classmethod 105 | def deserialize(cls, data: Union[dict, SerializedState]) -> ShopObs: 106 | if not isinstance(data, cls.SerializedState): 107 | data = cls.SerializedState(**data) 108 | 109 | cards = [] 110 | for serialized_card in data.cards: 111 | shop_card = types.ShopCard.deserialize(serialized_card) 112 | if shop_card.id != CardCatalog.NONE.id: 113 | cards.append(shop_card) 114 | 115 | relics = [] 116 | for serialized_relic in data.relics: 117 | shop_relic = types.ShopRelic.deserialize(serialized_relic) 118 | if shop_relic.id != RelicCatalog.NONE.id: 119 | relics.append(shop_relic) 120 | 121 | potions = [] 122 | for serialized_potion in data.potions: 123 | shop_potion = types.ShopPotion.deserialize(serialized_potion) 124 | if shop_potion.id != PotionCatalog.NONE.id: 125 | potions.append(shop_potion) 126 | 127 | purge_available = bool(data.purge.available) 128 | purge_cost = utils.from_binary_array(data.purge.price) 129 | 130 | return cls( 131 | cards=cards, 132 | relics=relics, 133 | potions=potions, 134 | purge_available=purge_available, 135 | purge_cost=purge_cost, 136 | ) 137 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/types/relics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | from gymnasium.spaces import Dict, Discrete, MultiBinary 7 | from pydantic import BaseModel, validator 8 | 9 | import gym_sts.spaces.constants.relics as relic_consts 10 | import gym_sts.spaces.constants.shop as shop_consts 11 | from gym_sts.spaces.constants.relics import RelicCatalog 12 | from gym_sts.spaces.observations import utils 13 | 14 | from .base import BinaryArray, ShopMixin 15 | 16 | 17 | class RelicBase(BaseModel): 18 | id: str 19 | name: str 20 | 21 | @classmethod 22 | def _serialize(cls, relic_id: str, discrete=False) -> Union[BinaryArray, int]: 23 | relic_idx = 
RelicCatalog.ids.index(relic_id) 24 | if discrete: 25 | return relic_idx 26 | else: 27 | return utils.to_binary_array(relic_idx, relic_consts.LOG_NUM_RELICS) 28 | 29 | @classmethod 30 | def serialize_empty(cls, discrete=False) -> Union[BinaryArray, int]: 31 | return cls._serialize(RelicCatalog.NONE.id, discrete=discrete) 32 | 33 | def serialize(self, discrete=False) -> Union[BinaryArray, int]: 34 | return self._serialize(self.id, discrete=discrete) 35 | 36 | @classmethod 37 | def deserialize(cls, relic_idx: Union[int, BinaryArray]) -> RelicBase: 38 | if isinstance(relic_idx, np.ndarray): 39 | relic_idx = utils.from_binary_array(relic_idx) 40 | 41 | relic_id = RelicCatalog.ids[relic_idx] 42 | relic_meta: relic_consts.RelicMetadata = getattr(RelicCatalog, relic_id) 43 | 44 | return cls(id=relic_id, name=relic_meta.name) 45 | 46 | def __lt__(self, other: object) -> bool: 47 | if not isinstance(other, RelicBase): 48 | return NotImplemented 49 | 50 | return self.id < other.id 51 | 52 | 53 | class Relic(RelicBase): 54 | counter: int = 0 55 | 56 | @validator("counter", pre=True) 57 | def must_be_nonnegative(cls, v: int) -> int: 58 | # STS uses values of -1 and -2 (to indicate the relic has no counter, or that 59 | # it's exhausted, respectively), so we pad by 3 so all values are positive. 60 | # The "NONE" relic should have a value of 0. 
61 | return v + 3 62 | 63 | @staticmethod 64 | def space() -> Dict: 65 | return Dict( 66 | { 67 | "id": Discrete(relic_consts.NUM_RELICS), 68 | "counter": MultiBinary(relic_consts.LOG_MAX_COUNTER), 69 | } 70 | ) 71 | 72 | @classmethod 73 | def serialize_empty(cls) -> dict: # type: ignore[override] 74 | return { 75 | "id": super().serialize_empty(discrete=True), 76 | "counter": utils.to_binary_array(0, relic_consts.LOG_MAX_COUNTER), 77 | } 78 | 79 | def serialize(self) -> dict: # type: ignore[override] 80 | return { 81 | "id": super().serialize(discrete=True), 82 | "counter": utils.to_binary_array( 83 | min(self.counter, relic_consts.MAX_COUNTER), 84 | relic_consts.LOG_MAX_COUNTER, 85 | ), 86 | } 87 | 88 | class SerializedState(BaseModel): 89 | id: int 90 | counter: BinaryArray 91 | 92 | class Config: 93 | arbitrary_types_allowed = True 94 | 95 | @classmethod 96 | def deserialize( # type: ignore[override] 97 | cls, data: Union[dict, SerializedState] 98 | ) -> Relic: 99 | if not isinstance(data, cls.SerializedState): 100 | data = cls.SerializedState(**data) 101 | 102 | relic_base = RelicBase.deserialize(data.id) 103 | counter = utils.from_binary_array(data.counter) - 3 104 | 105 | return cls(id=relic_base.id, name=relic_base.name, counter=counter) 106 | 107 | 108 | class ShopRelic(RelicBase, ShopMixin): 109 | @staticmethod 110 | def space() -> Dict: 111 | return Dict( 112 | { 113 | "relic": MultiBinary(relic_consts.LOG_NUM_RELICS), 114 | "price": MultiBinary(shop_consts.SHOP_LOG_MAX_PRICE), 115 | } 116 | ) 117 | 118 | @classmethod 119 | def serialize_empty(cls) -> dict: # type: ignore[override] 120 | return { 121 | "relic": super().serialize_empty(), 122 | "price": cls.serialize_empty_price(), 123 | } 124 | 125 | def serialize(self) -> dict: # type: ignore[override] 126 | return {"relic": super().serialize(), "price": self.serialize_price()} 127 | 128 | class SerializedState(BaseModel): 129 | relic: BinaryArray 130 | price: BinaryArray 131 | 132 | class Config: 133 | 
            arbitrary_types_allowed = True

    @classmethod
    def deserialize(  # type: ignore[override]
        cls, data: Union[dict, SerializedState]
    ) -> ShopRelic:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        relic = RelicBase.deserialize(data.relic)
        price = ShopMixin.deserialize_price(data.price)

        return cls(**relic.dict(), price=price)
--------------------------------------------------------------------------------
/gym_sts/spaces/observations/types/cards.py:
--------------------------------------------------------------------------------
from __future__ import annotations

from typing import Literal, Union

import numpy as np
from gymnasium.spaces import Dict, Discrete, MultiBinary
from pydantic import BaseModel, Field, NonNegativeInt, validator

import gym_sts.spaces.constants.cards as card_consts
from gym_sts.spaces.constants.cards import CardCatalog
from gym_sts.spaces.observations import utils

from .base import BinaryArray, ShopMixin


class Card(BaseModel):
    exhausts: bool
    cost: Union[NonNegativeInt, Literal["U", "X"]]
    name: str
    id: str
    ethereal: bool
    upgrades: NonNegativeInt
    has_target: bool

    @validator("cost", pre=True)
    def decode_special_costs(cls, v: int) -> Union[NonNegativeInt, Literal["U", "X"]]:
        if v == -1:
            return "X"
        elif v == -2:
            return "U"
        return v

    @classmethod
    def _serialize(
        cls, card_id: str, upgrades: int, discrete=False
    ) -> Union[BinaryArray, int]:
        card_idx = CardCatalog.ids.index(card_id)
        if discrete:
            card_idx *= 2
            if upgrades > 0:
                card_idx += 1

            return card_idx

        else:
            array = utils.to_binary_array(card_idx, card_consts.LOG_NUM_CARDS)

            # TODO support more than 1 upgrade
            upgrade_bit = [0]
            if upgrades > 0:
                upgrade_bit = [1]

            array = np.concatenate([upgrade_bit, array], axis=0)

            return array

    @classmethod
    def serialize_empty(cls, discrete=False) -> Union[BinaryArray, int]:
        upgrades = 0
        return cls._serialize(CardCatalog.NONE.id, upgrades, discrete=discrete)

    def serialize(self, discrete=False) -> Union[BinaryArray, int]:
        return self._serialize(self.id, self.upgrades, discrete=discrete)

    @classmethod
    def deserialize(cls, ser_data: Union[int, BinaryArray]) -> Card:
        if isinstance(ser_data, np.ndarray):
            if len(ser_data) != card_consts.LOG_NUM_CARDS + 1:
                raise ValueError("Card encoding has unexpected length")

            upgraded = ser_data[0]
            card_bits = ser_data[1:]
            card_idx = utils.from_binary_array(card_bits)
        else:
            card_idx, upgraded = divmod(ser_data, 2)

        card_id = card_consts.CardCatalog.ids[card_idx]
        card_meta: card_consts.CardMetadata = getattr(card_consts.CardCatalog, card_id)
        card_props = card_meta.upgraded if upgraded else card_meta.unupgraded

        return cls(
            id=card_id,
            name=card_meta.name,
            # TODO may be wrong because we don't currently serialize cost
            cost=card_props.default_cost,
            exhausts=card_props.exhausts,
            ethereal=card_props.ethereal,
            has_target=card_props.has_target,
            # TODO may be wrong for cards that can be upgraded 2+ times
            upgrades=upgraded,
        )

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Card):
            return NotImplemented

        if self.id != other.id:
            return self.id < other.id

        return self.upgrades < other.upgrades


class HandCard(Card):
    is_playable: bool

    @staticmethod
    def space() -> Dict:
        return Dict(
            {
                "card": MultiBinary(card_consts.LOG_NUM_CARDS_WITH_UPGRADES),
                "is_playable": Discrete(2),
            }
        )

    @classmethod
    def serialize_empty(cls) -> dict:  # type: ignore[override]
        return {
            "card": super().serialize_empty(),
            "is_playable": 0,
        }

    def serialize(self) -> dict:  # type: ignore[override]
        return {
            "card": super().serialize(),
            "is_playable": int(self.is_playable),
        }

    class SerializedState(BaseModel):
        card: BinaryArray
        is_playable: int = Field(..., ge=0, le=1)

        class Config:
            arbitrary_types_allowed = True

    @classmethod
    def deserialize(  # type: ignore[override]
        cls, data: Union[dict, SerializedState]
    ) -> HandCard:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        card = Card.deserialize(data.card)
        return cls(**card.dict(), is_playable=bool(data.is_playable))


class ShopCard(Card, ShopMixin):
    @classmethod
    def serialize_empty(cls) -> dict:  # type: ignore[override]
        return {
            "card": super().serialize_empty(),
            "price": cls.serialize_empty_price(),
        }

    def serialize(self) -> dict:  # type: ignore[override]
        return {"card": super().serialize(), "price": self.serialize_price()}

    class SerializedState(BaseModel):
        card: BinaryArray
        price: BinaryArray

        class Config:
            arbitrary_types_allowed = True

    @classmethod
    def deserialize(  # type: ignore[override]
        cls, data: Union[dict, SerializedState]
    ) -> ShopCard:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        card = Card.deserialize(data.card)
        price = ShopMixin.deserialize_price(data.price)

        return cls(**card.dict(), price=price)
--------------------------------------------------------------------------------
/gym_sts/spaces/observations/types/rewards.py:
--------------------------------------------------------------------------------
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Union

from gymnasium.spaces import Dict, Discrete, MultiBinary
from pydantic import BaseModel, Field

import gym_sts.spaces.constants.base as base_consts
import gym_sts.spaces.constants.rewards as reward_consts
from gym_sts.spaces.constants.relics import RelicCatalog
from gym_sts.spaces.observations import utils

from .base import BinaryArray
from .potions import PotionBase
from .relics import RelicBase


class Reward(BaseModel, ABC):
    @staticmethod
    def space():
        return Dict(
            {
                "type": Discrete(reward_consts.NUM_REWARD_TYPES),
                # Could be a gold value, a relic ID, the color of a key, or a potion ID
                "value": MultiBinary(reward_consts.COMBAT_REWARD_LOG_MAX_ID),
            }
        )

    @abstractmethod
    def serialize(self) -> dict:
        raise NotImplementedError("Unimplemented")

    @staticmethod
    def serialize_empty():
        return {
            "type": reward_consts.RewardType.NONE,
            "value": utils.to_binary_array(0, reward_consts.COMBAT_REWARD_LOG_MAX_ID),
        }

    class SerializedState(BaseModel):
        type: int
        value: BinaryArray

        class Config:
            arbitrary_types_allowed = True

    class NotDeserializable(Exception):
        pass

    @classmethod
    def deserialize(cls, data: Union[dict, SerializedState]) -> Reward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        reward_type = reward_consts.RewardType(data.type)

        if reward_type == reward_consts.RewardType.GOLD:
            return GoldReward.deserialize(data)
        elif reward_type == reward_consts.RewardType.POTION:
            return PotionReward.deserialize(data)
        elif reward_type == reward_consts.RewardType.RELIC:
            return RelicReward.deserialize(data)
        elif reward_type == reward_consts.RewardType.CARD:
            return CardReward.deserialize(data)
        elif reward_type == reward_consts.RewardType.KEY:
            return KeyReward.deserialize(data)
        elif reward_type == reward_consts.RewardType.NONE:
            raise cls.NotDeserializable()
        else:
            raise ValueError(f"Unrecognized reward type {reward_type}")


class GoldReward(Reward):
    value: int = Field(..., ge=0, lt=2**reward_consts.COMBAT_REWARD_LOG_MAX_ID)

    def serialize(self) -> dict:
        return {
            "type": reward_consts.RewardType.GOLD,
            "value": utils.to_binary_array(
                self.value, reward_consts.COMBAT_REWARD_LOG_MAX_ID
            ),
        }

    @classmethod
    def deserialize(cls, data: Union[dict, Reward.SerializedState]) -> GoldReward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        value = utils.from_binary_array(data.value)

        return cls(value=value)


class PotionReward(Reward):
    value: PotionBase

    def serialize(self) -> dict:
        potion_idx = self.value.serialize(discrete=True)
        assert isinstance(potion_idx, int)

        return {
            "type": reward_consts.RewardType.POTION,
            "value": utils.to_binary_array(
                potion_idx, reward_consts.COMBAT_REWARD_LOG_MAX_ID
            ),
        }

    @classmethod
    def deserialize(cls, data: Union[dict, Reward.SerializedState]) -> PotionReward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        potion_idx = utils.from_binary_array(data.value)
        potion = PotionBase.deserialize(potion_idx)

        return cls(value=potion)


class RelicReward(Reward):
    value: RelicBase

    def serialize(self) -> dict:
        relic_idx = RelicCatalog.ids.index(self.value.id)
        return {
            "type": reward_consts.RewardType.RELIC,
            "value": utils.to_binary_array(
                relic_idx, reward_consts.COMBAT_REWARD_LOG_MAX_ID
            ),
        }

    @classmethod
    def deserialize(cls, data: Union[dict, Reward.SerializedState]) -> RelicReward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        relic_idx = utils.from_binary_array(data.value)
        relic = RelicBase.deserialize(relic_idx)

        return cls(value=relic)


class CardReward(Reward):
    def serialize(self) -> dict:
        return {
            "type": reward_consts.RewardType.CARD,
            "value": utils.to_binary_array(0, reward_consts.COMBAT_REWARD_LOG_MAX_ID),
        }

    @classmethod
    def deserialize(cls, data: Union[dict, Reward.SerializedState]) -> CardReward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        return cls()


class KeyReward(Reward):
    value: str

    def serialize(self) -> dict:
        key_idx = base_consts.ALL_KEYS.index(self.value)
        return {
            "type": reward_consts.RewardType.KEY,
            "value": utils.to_binary_array(
                key_idx, reward_consts.COMBAT_REWARD_LOG_MAX_ID
            ),
        }

    @classmethod
    def deserialize(cls, data: Union[dict, Reward.SerializedState]) -> KeyReward:
        if not isinstance(data, cls.SerializedState):
            data = cls.SerializedState(**data)

        key_idx = utils.from_binary_array(data.value)
        key = base_consts.ALL_KEYS[key_idx]

        return cls(value=key)
--------------------------------------------------------------------------------
/gym_sts/spaces/constants/combat.py:
--------------------------------------------------------------------------------
import math


MAX_ATTACK = 999
MAX_ATTACK_TIMES = 15
MAX_BLOCK = 999
MAX_EFFECT = 999
MAX_HAND_SIZE = 10
MAX_NUM_ENEMIES = 6

LOG_MAX_ATTACK = math.ceil(math.log(MAX_ATTACK, 2))
LOG_MAX_ATTACK_TIMES = math.ceil(math.log(MAX_ATTACK_TIMES, 2))
LOG_MAX_BLOCK = math.ceil(math.log(MAX_BLOCK, 2))
LOG_MAX_EFFECT = math.ceil(math.log(MAX_EFFECT, 2))
LOG_MAX_ENERGY = 6
LOG_MAX_TURN = 8

ALL_EFFECTS = [
    "Conserve",
    "Sharp Hide",
    "Evolve",
    "Double Tap",
    "Draw Card",
    "Demon Form",
    "Thorns",
    "Wraith Form v2",
    "Angry",
    "Time Warp",
    "IntangiblePlayer",
    "Ritual",
    "Intangible",
    "Flex",
    "Thievery",
    "Metallicize",
    "Fading",
    "BackAttack",
    "Tools Of The Trade",
    "Unawakened",
    "Entangled",
    "Plated Armor",
    "Regenerate",
    "Choked",
    "Storm",
    "Constricted",
    "Feel No Pain",
    "Dexterity",
    "CorpseExplosionPower",
    "Stasis",
    "Panache",
    "Magnetism",
    "After Image",
    "Nullify Attack",
    "RechargingCore",
    "GrowthPower",
    "Flight",
    "Rebound",
    "Confusion",
    "Mode Shift",
    "Curiosity",
    "Double Damage",
    "Artifact",
    "Berserk",
    "Amplify",
    "Anger",
    "Loop",
    "Creative AI",
    "Next Turn Block",
    "Generic Strength Up Power",
    "Rage",
    "Strength",
    "Curl Up",
    "Thousand Cuts",
    "Equilibrium",
    "Sadistic",
    "Painful Stabs",
    "Flame Barrier",
    "Mayhem",
    "Poison",
    "Regeneration",
    "Dark Embrace",
    "Skill Burn",
    "Weakened",
    "Life Link",
    "Draw Reduction",
    "Juggernaut",
    "Pen Nib",
    "Electro",
    "Fire Breathing",
    "Buffer",
    "Rupture",
    "Malleable",
    "Spore Cloud",
    "Vulnerable",
    "StrikeUp",
    "Night Terror",
    "Collect",
    "BlockReturnPower",
    "DevotionPower",
    "EnergyDownPower",
    "BattleHymn",
    "WrathNextTurnPower",
    "MasterRealityPower",
    "Vault",
    "DevaForm",
    "Controlled",
    "OmnisciencePower",
    "PathToVictoryPower",
    "Study",
    "FreeAttackPower",
    "OmegaPower",
    "AngelForm",
    "Adaptation",
    "Nirvana",
    "WireheadingPower",
    "CannotChangeStancePower",
    "NoSkills",
    "LikeWaterPower",
    "EndTurnDeath",
    "EstablishmentPower",
    "Vigor",
    "WaveOfTheHandPower",
    "Mantra",
    "Lockon",
    "Life Link",
    "TheBomb",
    "Compulsive",
    "Hello",
    "Winter",
    "Split",
    "Shifting",
    "Focus",
    "Phantasmal",
    "Attack Burn",
    "Minion",
    "Noxious Fumes",
    "Envenom",
    "Brutality",
    "NoBlockPower",
    "Burst",
    "EnergizedBlue",
    "Explosive",
    "Bias",
    "DuplicationPower",
    "Corruption",
    "StaticDischarge",
    "Slow",
    "Lightning Mastery",
    "Draw",
    "Combust",
    "Hex",
    "Frail",
    "Surrounded",
    "Heatsink",
    "TimeMazePower",
    "DexLoss",
    "Shackled",
    "Retain Cards",
    "Echo Form",
    "Energized",
    "Repair",
    "No Draw",
    "Invincible",
    "FlickPower",
    "AlwaysMad",
    "HotHot",
    "MasterRealityPower",
    "FlowPower",
    "DisciplinePower",
    "DEPRECATEDCondense",
    "EmotionalTurmoilPower",
    "Grounded",
    "Retribution",
    "Serenity",
    "Mastery",
    "Barricade",
    "Blur",
    "BeatOfDeath",
    "Infinite Blades",
    "Accuracy",
]
# The wiki seems to list 108 buffs and debuffs; I may have missed a few.
NUM_EFFECTS = len(ALL_EFFECTS)

ALL_INTENTS = [
    "NONE",
    "ATTACK",
    "ATTACK_BUFF",
    "ATTACK_DEBUFF",
    "ATTACK_DEFEND",
    "BUFF",
    "DEBUFF",
    "STRONG_DEBUFF",
    "DEBUG",
    "DEFEND",
    "DEFEND_DEBUFF",
    "DEFEND_BUFF",
    "ESCAPE",
    "MAGIC",
    "SLEEP",
    "STUN",
    "UNKNOWN",
]
NUM_INTENTS = len(ALL_INTENTS)

ALL_MONSTER_TYPES = [
    "NONE",
    "GremlinNob",
    "GremlinTsundere",
    "FungiBeast",
    "GremlinThief",
    "TheGuardian",
    "FuzzyLouseNormal",
    "GremlinWarrior",
    "Looter",
    "Lagavulin",
    "AcidSlime_L",
    "HexaghostOrb",
    "Hexaghost",
    "SlaverBlue",
    "Sentry",
    "AcidSlime_S",
    "SpikeSlime_S",
    "GremlinWizard",
    "FuzzyLouseDefensive",
    "SpikeSlime_M",
    "AcidSlime_M",
    "Cultist",
    "Apology Slime",
    "SlimeBoss",
    "HexaghostBody",
    "SpikeSlime_L",
    "GremlinFat",
    "SlaverRed",
    "JawWorm",
    "BronzeOrb",
    "BookOfStabbing",
    "TheCollector",
    "Snecko",
    "BanditBear",
    "SlaverBoss",
    "TorchHead",
    "Shelled Parasite",
    "Centurion",
    "Chosen",
    "BronzeAutomaton",
    "Healer",
    "BanditChild",
    "BanditLeader",
    "SphericGuardian",
    "SnakePlant",
    "Champ",
    "Mugger",
    "Byrd",
    "GremlinLeader",
    "Serpent",
    "Darkling",
    "Orb Walker",
    "Donu",
    "Maw",
    "Spiker",
    "AwakenedOne",
    "TimeEater",
    "Repulsor",
    "WrithingMass",
    "Deca",
    "Exploder",
    "Reptomancer",
    "Transient",
    "Nemesis",
    "Dagger",
    "GiantHead",
    "SpireShield",
    "SpireSpear",
    "CorruptHeart",
]
NUM_MONSTER_TYPES = len(ALL_MONSTER_TYPES)

ALL_ORBS = [
    "NONE",  # Indicates the slot does not exist
    "Empty",
    "Dark",
    "Frost",
    "Lightning",
    "Plasma",
]
NUM_ORBS = len(ALL_ORBS)
MAX_ORB_SLOTS = 10
--------------------------------------------------------------------------------
/gym_sts/spaces/observations/observations.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import functools
from typing import Union

import numpy as np
from gymnasium import spaces
from pydantic import BaseModel

from gym_sts.spaces import actions
from gym_sts.spaces.constants.base import ScreenType

from . import components


class ObservationError(Exception):
    pass


OBSERVATION_SPACE = spaces.Dict(
    {
        "persistent_state": components.PersistentStateObs.space(),
        "combat_state": components.CombatObs.space(),
        "shop_state": components.ShopObs.space(),
        "campfire_state": components.CampfireObs.space(),
        "card_reward_state": components.CardRewardObs.space(),
        "combat_reward_state": components.CombatRewardObs.space(),
        "event_state": components.EventStateObs.space(),
        "valid_action_mask": spaces.MultiBinary(len(actions.ACTIONS)),
    }
)


class Observation:
    class SerializedState(BaseModel):
        campfire_state: components.CampfireObs.SerializedState
        card_reward_state: components.CardRewardObs.SerializedState
        combat_state: components.CombatObs.SerializedState
        combat_reward_state: components.CombatRewardObs.SerializedState
        persistent_state: components.PersistentStateObs.SerializedState
        shop_state: components.ShopObs.SerializedState

    def __init__(self, state: Union[dict, SerializedState]):
        if isinstance(state, dict):
            game_state = state.get("game_state", {})
            screen_type = game_state.get("screen_type", ScreenType.NONE)
            screen_state = game_state.get("screen_state", {})

            self.persistent_state = components.PersistentStateObs(**game_state)

            self.combat_state = components.CombatObs(game_state)
            self.combat_reward_state = components.CombatRewardObs(game_state)

            shop_state = screen_state if screen_type == ScreenType.SHOP_SCREEN else {}
            self.shop_state = components.ShopObs(**shop_state)

            campfire_state = screen_state if screen_type == ScreenType.REST else {}
            self.campfire_state = components.CampfireObs(**campfire_state)

            card_reward_state = (
                screen_state if screen_type == ScreenType.CARD_REWARD else {}
            )
            self.card_reward_state = components.CardRewardObs(**card_reward_state)

            self.event_state = components.EventStateObs(state)

            # Keep a reference to the raw CommunicationMod response
            self.state = state
        else:
            self.campfire_state = components.CampfireObs.deserialize(
                state.campfire_state
            )
            self.card_reward_state = components.CardRewardObs.deserialize(
                state.card_reward_state
            )
            self.combat_state = components.CombatObs.deserialize(state.combat_state)
            self.combat_reward_state = components.CombatRewardObs.deserialize(
                state.combat_reward_state
            )
            self.persistent_state = components.PersistentStateObs.deserialize(
                state.persistent_state
            )
            self.shop_state = components.ShopObs.deserialize(state.shop_state)

            # TODO this doesn't really work because we assume the keys will be present
            # replace with a pydantic model?
            self.state = {}

    @property
    def has_error(self) -> bool:
        return "error" in self.state

    def check_for_error(self) -> None:
        if self.has_error:
            raise ObservationError(self.state["error"])

    @property
    def _available_commands(self) -> list[str]:
        self.check_for_error()
        return self.state["available_commands"]

    @property
    def choice_list(self) -> list[str]:
        self.check_for_error()
        if "choose" not in self._available_commands:
            return []

        game_state = self.state.get("game_state")
        if game_state is None:
            return []

        return game_state.get("choice_list", [])

    @property
    def game_over(self) -> bool:
        self.check_for_error()
        return self.screen_type == "GAME_OVER"

    @property
    def in_combat(self) -> bool:
        self.check_for_error()
        if "game_state" not in self.state:
            return False

        return "combat_state" in self.state["game_state"]

    @property
    def in_game(self) -> bool:
        self.check_for_error()
        return self.state["in_game"]

    @property
    def screen_type(self) -> str:
        self.check_for_error()
        if "game_state" in self.state:
            game_state = self.state["game_state"]
            screen_type = game_state["screen_type"]
        else:
            # CommunicationMod doesn't specify a screen type in the main menu
            screen_type = "MAIN_MENU"

        return screen_type

    @property
    def stable(self) -> bool:
        return self.state["ready_for_command"]

    @functools.cached_property
    def valid_actions(self) -> list[actions.Action]:
        # avoid circular import
        from gym_sts.envs.action_validation import get_valid

        return get_valid(self)

    def serialize(self) -> dict:
        valid_action_mask = np.zeros([len(actions.ACTIONS)], dtype=bool)
        for action in self.valid_actions:
            valid_action_mask[action._id] = True

        return {
            "persistent_state": self.persistent_state.serialize(),
            "combat_state": self.combat_state.serialize(),
            "shop_state": self.shop_state.serialize(),
            "campfire_state": self.campfire_state.serialize(),
            "card_reward_state": self.card_reward_state.serialize(),
            "combat_reward_state": self.combat_reward_state.serialize(),
            "event_state": self.event_state.serialize(),
            "valid_action_mask": valid_action_mask,
        }

    @classmethod
    def deserialize(cls, raw_data: dict) -> Observation:
        data = cls.SerializedState(**raw_data)
        return cls(data)
--------------------------------------------------------------------------------
/gym_sts/rl/mvp.py:
--------------------------------------------------------------------------------
"""Run with rllib."""
import os

import fancyflags as ff
import ray
from absl import app, logging
from gymnasium import spaces
from ray import tune
from ray.air import config
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms import ppo
from ray.rllib.models import preprocessors
from ray.train.rl import RLTrainer

from gym_sts.envs import base, single_combat
from gym_sts.rl import action_masking
from gym_sts.rl.metrics import StSCustomMetricCallbacks


def check_rllib_bug(space: spaces.Space):
    # rllib special-cases certain spaces, which we don't want
    if isinstance(space, spaces.Dict):
        for subspace in space.values():
            check_rllib_bug(subspace)
    elif not isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
        assert space.shape != preprocessors.ATARI_RAM_OBS_SHAPE


check_rllib_bug(base.OBSERVATION_SPACE)

action_masking.register()

ENV = ff.DEFINE_dict(
    "env",
    lib=ff.String("lib"),
    mods=ff.String("mods"),
    out=ff.String(None),
    headless=ff.Boolean(True),
    animate=ff.Boolean(False),
    ascension=ff.Integer(20),
    build_image=ff.Boolean(False),
    reboot_frequency=ff.Integer(50, "Reboot game every n resets."),
    reboot_on_error=ff.Boolean(False),
    log_states=ff.Boolean(False),
)

TUNE = ff.DEFINE_dict(
    "tune",
    run=dict(
        name=ff.String("sts-rl", "Name of the ray experiment"),
        local_dir=ff.String(None),  # default is ~/ray_results/
        verbose=ff.Integer(3),
    ),
    failure_config=dict(
        max_failures=ff.Integer(0)  # Set to -1 to enable infinite recovery retries
    ),
    checkpoint_config=dict(
        checkpoint_frequency=ff.Integer(20),
        checkpoint_at_end=ff.Boolean(False),
        num_to_keep=ff.Integer(3),
    ),
    restore=ff.String(
        None, "Path to experiment directory to restore from, e.g. ~/ray_results/sts-rl"
    ),
    sync_config=dict(
        upload_dir=ff.String(None, "Path to local or remote folder."),
        syncer=ff.String("auto"),
        sync_on_checkpoint=ff.Boolean(True),
        sync_period=ff.Integer(300),
    ),
)

WANDB = ff.DEFINE_dict(
    "wandb",
    use=ff.Boolean(False),
    entity=ff.String("sts-ai"),
    project=ff.String("sts-rllib"),
    api_key_file=ff.String(None),
    api_key=ff.String(None),
    log_config=ff.Boolean(False),
    upload_checkpoints=ff.Boolean(False),
)

RL = ff.DEFINE_dict(
    "rl",
    rollout_fragment_length=ff.Integer(32),
    train_batch_size=ff.Integer(1024),
    num_workers=ff.Integer(0),
    model=dict(
        custom_model=ff.String("masked"),
        fcnet_hiddens=ff.Sequence([256, 256, 256, 256]),
        fcnet_activation=ff.String("relu"),
    ),
    entropy_coeff=ff.Float(0.0),
)

SCALING = ff.DEFINE_dict(
    "scaling",
    num_workers=ff.Integer(0),
    use_gpu=ff.Boolean(False),
    trainer_resources=dict(CPU=ff.Integer(1), GPU=ff.Integer(0)),
    resources_per_worker=dict(CPU=ff.Integer(1), GPU=ff.Integer(0)),
)

SINGLE_COMBAT = ff.DEFINE_dict(
    "single_combat",
    use=ff.Boolean(False),
    enemies=ff.StringList(["3_Sentries"]),
    cards=ff.StringList(["Strike_B"] * 4 + ["Defend_B"] * 4 + ["Zap"] + ["Dualcast"]),
    add_relics=ff.StringList([]),
)


class Env(base.SlayTheSpireGymEnv):
    def __init__(self, cfg: dict):
        super().__init__(**cfg)


class SingleCombatEnv(single_combat.SingleCombatSTSEnv):
    def __init__(self, cfg: dict):
        super().__init__(**cfg)


def main(_):
    ray.init(address=None)
    # we need abspaths here because the cwd will be different later
    output_dir = ENV.value["out"]
    if output_dir is not None:
        output_dir = os.path.abspath(output_dir)

    env_config = {
        "lib_dir": os.path.abspath(ENV.value["lib"]),
        "mods_dir": os.path.abspath(ENV.value["mods"]),
        "output_dir": output_dir,
    }
    for key in [
        "headless",
        "animate",
        "reboot_frequency",
        "reboot_on_error",
        "ascension",
        "log_states",
    ]:
        env_config[key] = ENV.value[key]

    if SINGLE_COMBAT.value["use"]:
        env_config["enemies"] = SINGLE_COMBAT.value["enemies"]
        env_config["cards"] = SINGLE_COMBAT.value["cards"]
        env_config["add_relics"] = SINGLE_COMBAT.value["add_relics"]

    if ENV.value["build_image"]:
        logging.info("build_image")
        base.SlayTheSpireGymEnv.build_image()

    rl_config = RL.value.copy()

    ppo_config = {
        "env": SingleCombatEnv if SINGLE_COMBAT.value["use"] else Env,
        "env_config": env_config,
        "framework": "tf2",
        "eager_tracing": True,
        # "horizon": 64,  # just for reporting some rewards
        # "soft_horizon": True,
        # "no_done_at_end": True,
    }
    if SINGLE_COMBAT.value["use"]:
        ppo_config["callbacks"] = StSCustomMetricCallbacks

    ppo_config.update(rl_config)

    trainer = RLTrainer(
        scaling_config=config.ScalingConfig(**SCALING.value),
        algorithm=ppo.PPO,
        config=ppo_config,
    )

    callbacks = []
    wandb_config = WANDB.value.copy()
    if wandb_config.pop("use"):
        wandb_callback = WandbLoggerCallback(
            name=TUNE.value["run"]["name"], **wandb_config
        )
        callbacks.append(wandb_callback)

    tune_config = TUNE.value
    # We're doing a lot of direct key-based access of values in these dict flags.
    # The fancyflags docs consider this an antipattern, see:
    # https://github.com/deepmind/fancyflags#tips.
    sync_config = tune.SyncConfig(**tune_config["sync_config"])
    checkpoint_config = config.CheckpointConfig(**tune_config["checkpoint_config"])
    failure_config = config.FailureConfig(**tune_config["failure_config"])
    run_config = config.RunConfig(
        callbacks=callbacks,
        checkpoint_config=checkpoint_config,
        sync_config=sync_config,
        failure_config=failure_config,
        **tune_config["run"],
    )

    tuner = tune.Tuner(
        trainable=trainer,
        run_config=run_config,
    )

    restore_path = tune_config.get("restore")
    if restore_path:
        tuner = tune.Tuner.restore(restore_path, trainable=trainer, resume_errored=True)

    tuner.fit()


if __name__ == "__main__":
    app.run(main)
--------------------------------------------------------------------------------
/gym_sts/envs/action_validation.py:
--------------------------------------------------------------------------------
from gym_sts.spaces import actions
from gym_sts.spaces.constants.base import ScreenType
from gym_sts.spaces.observations import Observation


def validate_end_turn(action: actions.EndTurn, observation: Observation) -> bool:
    return "end" in observation._available_commands


def validate_return(action: actions.Return, observation: Observation) -> bool:
    # Once the agent proceeds to the map, we don't allow it to swap back
    # to the previous screen. This prevents the agent from looping back and
    # forth between the two screens for umpteen steps.
    if observation.screen_type == ScreenType.MAP:
        return False

    for word in ["cancel", "leave", "return", "skip"]:
        if word in observation._available_commands:
            return True

    return False


def validate_proceed(action: actions.Proceed, observation: Observation) -> bool:
    for word in ["confirm", "proceed"]:
        if word in observation._available_commands:
            return True

    return False


def _validate_choice(action: actions.Choose, observation: Observation) -> bool:
    return action.choice_index < len(observation.choice_list)


def validate_choose(action: actions.Choose, observation: Observation) -> bool:
    if "choose" not in observation._available_commands:
        return False

    if observation.in_combat:
        if observation.screen_type in ["CARD_REWARD", "GRID", "HAND_SELECT"]:
            return _validate_choice(action, observation)
        else:
            # TODO determine if there are any other choices that could
            # be made mid-combat, such as picking from deck/discard/exhaust,
            # or scrying.
            print("NOT IMPLEMENTED")
            return False
    elif observation.screen_type in [
        "BOSS_REWARD",
        "CARD_REWARD",
        "CHEST",
        "COMBAT_REWARD",
        "EVENT",
        "GRID",
        "MAP",
        "REST",
        "SHOP_ROOM",
        "SHOP_SCREEN",
    ]:
        return _validate_choice(action, observation)
    else:
        # TODO handle choices outside of combat, like events
        print("NOT IMPLEMENTED")
        return True


def validate_play(action: actions.PlayCard, observation: Observation) -> bool:
    if "play" not in observation._available_commands:
        return False

    if not observation.in_combat or observation.screen_type != "NONE":
        return False

    # Choices correspond to playing cards
    hand = observation.combat_state.hand
    index = action.card_position
    # Adjust to account for CommunicationMod's odd indexing scheme.
    index -= 1
    if index < 0:
        index += 10

    if index >= len(hand):
        return False

    card = hand[index]

    target_index = action.target_index

    # Technically it should be invalid to specify a target if the card
    # doesn't take a target (and this would cut down on the number of valid
    # actions), but the game simply ignores the target choice, so it's not an
    # error. Because we only want actions to be invalid if the game truly won't
    # accept them, we've commented this validation check out for now.
    # if target_index is not None and not card.has_target:
    #     return False

    if target_index is None and card.has_target:
        return False

    enemies = observation.combat_state.enemies
    if target_index is not None:
        # Even if the card doesn't take a target, STS still requires the stated
        # target index to be in bounds.
        if target_index >= len(enemies):
            return False

    return card.is_playable


def _validate_potion(
    action: actions.PotionAction, observation: Observation, prop: str
) -> bool:
    if "potion" not in observation._available_commands:
        return False

    index = action.potion_index
    potions = observation.persistent_state.potions
    if index >= len(potions):
        return False

    potion = potions[index]
    return getattr(potion, prop)


def validate_use_potion(action: actions.UsePotion, observation: Observation) -> bool:
    if not _validate_potion(action, observation, "can_use"):
        return False

    index = action.potion_index
    potions = observation.persistent_state.potions
    potion = potions[index]

    target_index = action.target_index

    # Technically it should be invalid to specify a target if the potion
    # doesn't take a target (and this would cut down on the number of valid
    # actions), but the game simply ignores the target choice, so it's not an
    # error. Because we only want actions to be invalid if the game truly won't
    # accept them, we've commented this validation check out for now.
    # if target_index is not None and not potion.requires_target:
    #     return False

    if potion.requires_target:
        # Explosive Potion is basically incorrectly defined within STS.
        # It doesn't actually require a target.
        if potion.id == "Explosive Potion":
            return True

        if target_index is None:
            return False

        # Unlike when playing cards, STS disregards out-of-range target indices
        # when using potions that don't take a target.
        enemies = observation.combat_state.enemies
        if target_index >= len(enemies):
            return False

    return True


def validate_discard_potion(
    action: actions.DiscardPotion, observation: Observation
) -> bool:
    return _validate_potion(action, observation, "can_discard")


def validate(action: actions.Action, observation: Observation) -> bool:
    if isinstance(action, actions.EndTurn):
        return validate_end_turn(action, observation)

    elif isinstance(action, actions.Return):
        return validate_return(action, observation)

    elif isinstance(action, actions.Proceed):
        return validate_proceed(action, observation)

    elif isinstance(action, actions.Choose):
        return validate_choose(action, observation)

    elif isinstance(action, actions.UsePotion):
        return validate_use_potion(action, observation)

    elif isinstance(action, actions.DiscardPotion):
        return validate_discard_potion(action, observation)

    elif isinstance(action, actions.PlayCard):
        return validate_play(action, observation)

    raise ValueError("Unrecognized action type")


def get_valid(observation: Observation):
    # Note: this method is rather inefficient.
We could instead generate the 195 | # valid actions from the observation. 196 | return [a for a in actions.ACTIONS if validate(a, observation)] 197 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/persistent.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | from gymnasium.spaces import Dict, Discrete, MultiBinary, Tuple 7 | from pydantic import BaseModel, Field, root_validator, validator 8 | 9 | import gym_sts.spaces.constants.base as base_consts 10 | import gym_sts.spaces.constants.potions as potion_consts 11 | import gym_sts.spaces.constants.relics as relic_consts 12 | from gym_sts.spaces.constants.cards import CardCatalog, CardMetadata 13 | from gym_sts.spaces.observations import serializers, spaces, types, utils 14 | 15 | from .base import PydanticComponent 16 | 17 | 18 | class PersistentStateObs(PydanticComponent): 19 | floor: int = 0 20 | hp: int = Field(0, alias="current_hp") 21 | max_hp: int = 0 22 | gold: int = 0 23 | potions: list[types.Potion] = [] 24 | relics: list[types.Relic] = [] 25 | deck: list[types.Card] = [] 26 | keys: types.Keys = types.Keys() 27 | map: types.Map = types.Map() 28 | screen_type: base_consts.ScreenType = base_consts.ScreenType.EMPTY 29 | 30 | @root_validator(pre=True) 31 | def combine_map_inputs(cls, values): 32 | """ 33 | CommunicationMod provides the map nodes and act boss separately, but we'd 34 | rather combine them into one Pydantic model. To do this, we restructure the 35 | input so our Pydantic model will deserialize it properly. 
36 | """ 37 | 38 | try: 39 | map = values["map"] 40 | if not isinstance(map, types.Map): 41 | restructured_map = { 42 | "nodes": map, 43 | "boss": values["act_boss"], 44 | } 45 | values["map"] = restructured_map 46 | except KeyError: 47 | pass 48 | 49 | return values 50 | 51 | @validator("deck") 52 | def ensure_deck_sorted(cls, v: list[types.Card]) -> list[types.Card]: 53 | v.sort() 54 | return v 55 | 56 | @validator("relics") 57 | def ensure_relics_sorted(cls, v: list[types.Relic]) -> list[types.Relic]: 58 | v.sort() 59 | return v 60 | 61 | @staticmethod 62 | def space(): 63 | return Dict( 64 | { 65 | "floor": MultiBinary(base_consts.LOG_NUM_FLOORS), 66 | "health": spaces.generate_health_space(), 67 | "gold": MultiBinary(base_consts.LOG_MAX_GOLD), 68 | "potions": Tuple( 69 | [types.Potion.space()] * potion_consts.NUM_POTION_SLOTS 70 | ), 71 | "relics": Tuple( 72 | [MultiBinary(relic_consts.LOG_MAX_COUNTER)] 73 | * relic_consts.NUM_RELICS 74 | ), 75 | "deck": spaces.generate_card_space(), 76 | "keys": MultiBinary(base_consts.NUM_KEYS), 77 | "map": types.Map.space(), 78 | "screen_type": Discrete(len(base_consts.ScreenType.__members__)), 79 | } 80 | ) 81 | 82 | def serialize(self) -> dict: 83 | floor = utils.to_binary_array(self.floor, base_consts.LOG_NUM_FLOORS) 84 | health = types.Health(hp=self.hp, max_hp=self.max_hp).serialize() 85 | gold = utils.to_binary_array(self.gold, base_consts.LOG_MAX_GOLD) 86 | 87 | potions = [types.Potion.serialize_empty()] * potion_consts.NUM_POTION_SLOTS 88 | 89 | for i, potion in enumerate(self.potions): 90 | potions[i] = potion.serialize() 91 | 92 | relics = [ 93 | np.zeros(relic_consts.LOG_MAX_COUNTER, dtype=bool) 94 | ] * relic_consts.NUM_RELICS 95 | 96 | for relic in self.relics: 97 | ser = relic.serialize() 98 | relics[ser["id"]] = ser["counter"] 99 | 100 | deck = serializers.serialize_cards(self.deck) 101 | 102 | keys = self.keys.serialize() 103 | map = self.map.serialize() 104 | 105 | response = { 106 | "floor": floor, 107 | 
"health": health, 108 | "gold": gold, 109 | "potions": potions, 110 | "relics": relics, 111 | "deck": deck, 112 | "keys": keys, 113 | "map": map, 114 | "screen_type": list(base_consts.ScreenType.__members__).index( 115 | self.screen_type.value 116 | ), 117 | } 118 | 119 | return response 120 | 121 | class SerializedState(BaseModel): 122 | floor: types.BinaryArray 123 | health: types.Health.SerializedState 124 | gold: types.BinaryArray 125 | potions: list[types.Potion.SerializedState] 126 | relics: list[types.BinaryArray] 127 | deck: types.BinaryArray 128 | keys: types.BinaryArray 129 | map: types.Map.SerializedState 130 | screen_type: int 131 | 132 | class Config: 133 | arbitrary_types_allowed = True 134 | 135 | @classmethod 136 | def deserialize(cls, data: Union[dict, SerializedState]) -> PersistentStateObs: 137 | if not isinstance(data, cls.SerializedState): 138 | data = cls.SerializedState(**data) 139 | 140 | floor = utils.from_binary_array(data.floor) 141 | hp = utils.from_binary_array(data.health.hp) 142 | max_hp = utils.from_binary_array(data.health.max_hp) 143 | gold = utils.from_binary_array(data.gold) 144 | 145 | potions = [] 146 | for p in data.potions: 147 | potion = types.Potion.deserialize(p) 148 | if potion.id != "NONE": 149 | potions.append(potion) 150 | 151 | relics = [] 152 | for idx, r in enumerate(data.relics): 153 | relic_data = { 154 | "id": idx, 155 | "counter": r, 156 | } 157 | relic = types.Relic.deserialize(relic_data) 158 | if relic.id != relic_consts.RelicCatalog.NONE.id and relic.counter > 0: 159 | relics.append(relic) 160 | 161 | deck = [] 162 | for _card_idx, count in enumerate(data.deck): 163 | card_idx, upgrade_bit = divmod(_card_idx, 2) 164 | card_id = CardCatalog.ids[card_idx] 165 | if card_id != CardCatalog.NONE.id and count > 0: 166 | card_meta: CardMetadata = getattr(CardCatalog, card_id) 167 | card_props = card_meta.upgraded if upgrade_bit else card_meta.unupgraded 168 | 169 | for _ in range(count): 170 | card = types.Card( 171 
| id=card_id, 172 | name=card_meta.name, 173 | # TODO may be wrong because we don't currently serialize cost 174 | cost=card_props.default_cost, 175 | exhausts=card_props.exhausts, 176 | ethereal=card_props.ethereal, 177 | has_target=card_props.has_target, 178 | # TODO may be wrong for cards that can be upgraded 2+ times 179 | upgrades=upgrade_bit, 180 | ) 181 | deck.append(card) 182 | 183 | keys = types.Keys.deserialize(data.keys) 184 | map = types.Map.deserialize(data.map) 185 | screen_type_str = list(base_consts.ScreenType.__members__)[data.screen_type] 186 | screen_type = base_consts.ScreenType(screen_type_str) 187 | 188 | return cls( 189 | floor=floor, 190 | current_hp=hp, 191 | max_hp=max_hp, 192 | gold=gold, 193 | potions=potions, 194 | relics=relics, 195 | deck=deck, 196 | keys=keys, 197 | map=map, 198 | screen_type=screen_type, 199 | ) 200 | -------------------------------------------------------------------------------- /gym_sts/spaces/constants/potions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class PotionMetadata(BaseModel): 7 | id: str 8 | name: str 9 | requires_target: bool 10 | 11 | 12 | class _PotionCatalog: 13 | _id_to_meta = { 14 | # Indicates the potion slot does not exist 15 | "NONE": PotionMetadata( 16 | id="NONE", 17 | name="NONE", 18 | requires_target=False, 19 | ), 20 | "Potion Slot": PotionMetadata( 21 | id="Potion Slot", 22 | name="Potion Slot", 23 | requires_target=False, 24 | ), 25 | "EntropicBrew": PotionMetadata( 26 | id="EntropicBrew", 27 | name="Entropic Brew", 28 | requires_target=False, 29 | ), 30 | "Regen Potion": PotionMetadata( 31 | id="Regen Potion", 32 | name="Regen Potion", 33 | requires_target=False, 34 | ), 35 | "AttackPotion": PotionMetadata( 36 | id="AttackPotion", 37 | name="Attack Potion", 38 | requires_target=False, 39 | ), 40 | "SkillPotion": PotionMetadata( 41 | id="SkillPotion", 42 | name="Skill Potion", 43 
| requires_target=False, 44 | ), 45 | "LiquidMemories": PotionMetadata( 46 | id="LiquidMemories", 47 | name="Liquid Memories", 48 | requires_target=False, 49 | ), 50 | "SteroidPotion": PotionMetadata( 51 | id="SteroidPotion", 52 | name="Flex Potion", 53 | requires_target=False, 54 | ), 55 | "FairyPotion": PotionMetadata( 56 | id="FairyPotion", 57 | name="Fairy in a Bottle", 58 | requires_target=False, 59 | ), 60 | "Energy Potion": PotionMetadata( 61 | id="Energy Potion", 62 | name="Energy Potion", 63 | requires_target=False, 64 | ), 65 | "EssenceOfSteel": PotionMetadata( 66 | id="EssenceOfSteel", 67 | name="Essence of Steel", 68 | requires_target=False, 69 | ), 70 | "GamblersBrew": PotionMetadata( 71 | id="GamblersBrew", 72 | name="Gambler's Brew", 73 | requires_target=False, 74 | ), 75 | "PowerPotion": PotionMetadata( 76 | id="PowerPotion", 77 | name="Power Potion", 78 | requires_target=False, 79 | ), 80 | "PotionOfCapacity": PotionMetadata( 81 | id="PotionOfCapacity", 82 | name="Potion of Capacity", 83 | requires_target=False, 84 | ), 85 | "DuplicationPotion": PotionMetadata( 86 | id="DuplicationPotion", 87 | name="Duplication Potion", 88 | requires_target=False, 89 | ), 90 | "BlessingOfTheForge": PotionMetadata( 91 | id="BlessingOfTheForge", 92 | name="Blessing of the Forge", 93 | requires_target=False, 94 | ), 95 | "Swift Potion": PotionMetadata( 96 | id="Swift Potion", 97 | name="Swift Potion", 98 | requires_target=False, 99 | ), 100 | "CultistPotion": PotionMetadata( 101 | id="CultistPotion", 102 | name="Cultist Potion", 103 | requires_target=False, 104 | ), 105 | "StancePotion": PotionMetadata( 106 | id="StancePotion", 107 | name="Stance Potion", 108 | requires_target=False, 109 | ), 110 | "Fruit Juice": PotionMetadata( 111 | id="Fruit Juice", 112 | name="Fruit Juice", 113 | requires_target=False, 114 | ), 115 | "LiquidBronze": PotionMetadata( 116 | id="LiquidBronze", 117 | name="Liquid Bronze", 118 | requires_target=False, 119 | ), 120 | "HeartOfIron": 
PotionMetadata( 121 | id="HeartOfIron", 122 | name="Heart of Iron", 123 | requires_target=False, 124 | ), 125 | "Fire Potion": PotionMetadata( 126 | id="Fire Potion", 127 | name="Fire Potion", 128 | requires_target=True, 129 | ), 130 | "Ancient Potion": PotionMetadata( 131 | id="Ancient Potion", 132 | name="Ancient Potion", 133 | requires_target=False, 134 | ), 135 | "SmokeBomb": PotionMetadata( 136 | id="SmokeBomb", 137 | name="Smoke Bomb", 138 | requires_target=True, 139 | ), 140 | "Block Potion": PotionMetadata( 141 | id="Block Potion", 142 | name="Block Potion", 143 | requires_target=False, 144 | ), 145 | "BottledMiracle": PotionMetadata( 146 | id="BottledMiracle", 147 | name="Bottled Miracle", 148 | requires_target=False, 149 | ), 150 | "CunningPotion": PotionMetadata( 151 | id="CunningPotion", 152 | name="Cunning Potion", 153 | requires_target=False, 154 | ), 155 | "FocusPotion": PotionMetadata( 156 | id="FocusPotion", 157 | name="Focus Potion", 158 | requires_target=False, 159 | ), 160 | "EssenceOfDarkness": PotionMetadata( 161 | id="EssenceOfDarkness", 162 | name="Essence of Darkness", 163 | requires_target=False, 164 | ), 165 | "GhostInAJar": PotionMetadata( 166 | id="GhostInAJar", 167 | name="Ghost in a Jar", 168 | requires_target=False, 169 | ), 170 | "Explosive Potion": PotionMetadata( 171 | id="Explosive Potion", 172 | name="Explosive Potion", 173 | requires_target=True, 174 | ), 175 | "FearPotion": PotionMetadata( 176 | id="FearPotion", 177 | name="Fear Potion", 178 | requires_target=True, 179 | ), 180 | "SneckoOil": PotionMetadata( 181 | id="SneckoOil", 182 | name="Snecko Oil", 183 | requires_target=False, 184 | ), 185 | "Poison Potion": PotionMetadata( 186 | id="Poison Potion", 187 | name="Poison Potion", 188 | requires_target=True, 189 | ), 190 | "DistilledChaos": PotionMetadata( 191 | id="DistilledChaos", 192 | name="Distilled Chaos", 193 | requires_target=False, 194 | ), 195 | "SpeedPotion": PotionMetadata( 196 | id="SpeedPotion", 197 | 
name="Speed Potion", 198 | requires_target=False, 199 | ), 200 | "Strength Potion": PotionMetadata( 201 | id="Strength Potion", 202 | name="Strength Potion", 203 | requires_target=False, 204 | ), 205 | "Weak Potion": PotionMetadata( 206 | id="Weak Potion", 207 | name="Weak Potion", 208 | requires_target=True, 209 | ), 210 | "Dexterity Potion": PotionMetadata( 211 | id="Dexterity Potion", 212 | name="Dexterity Potion", 213 | requires_target=False, 214 | ), 215 | "ColorlessPotion": PotionMetadata( 216 | id="ColorlessPotion", 217 | name="Colorless Potion", 218 | requires_target=False, 219 | ), 220 | "Ambrosia": PotionMetadata( 221 | id="Ambrosia", 222 | name="Ambrosia", 223 | requires_target=False, 224 | ), 225 | "BloodPotion": PotionMetadata( 226 | id="BloodPotion", 227 | name="Blood Potion", 228 | requires_target=False, 229 | ), 230 | "ElixirPotion": PotionMetadata( 231 | id="ElixirPotion", 232 | name="Elixir", 233 | requires_target=False, 234 | ), 235 | } 236 | 237 | def __getattr__(self, attr): 238 | data = self._id_to_meta.get(attr) 239 | 240 | if data is None: 241 | raise AttributeError(attr) 242 | 243 | return data 244 | 245 | def __len__(self) -> int: 246 | return len(self._id_to_meta) 247 | 248 | @property 249 | def ids(self) -> list: 250 | return list(self._id_to_meta.keys()) 251 | 252 | 253 | PotionCatalog = _PotionCatalog() 254 | 255 | 256 | NUM_POTIONS = len(PotionCatalog) 257 | LOG_NUM_POTIONS = math.ceil(math.log(NUM_POTIONS, 2)) 258 | 259 | NUM_POTION_SLOTS = 5 260 | -------------------------------------------------------------------------------- /gym_sts/spaces/observations/components/combat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from gymnasium.spaces import Dict, MultiBinary, MultiDiscrete, Tuple 8 | from pydantic import BaseModel 9 | 10 | import 
gym_sts.spaces.constants.base as
base_consts 11 | import gym_sts.spaces.constants.combat as combat_consts 12 | from gym_sts.spaces.constants.cards import CardCatalog 13 | from gym_sts.spaces.observations import serializers, spaces, types, utils 14 | 15 | from .base import ObsComponent 16 | 17 | 18 | class CombatObs(ObsComponent): 19 | def __init__(self, state: dict): 20 | # Sane defaults 21 | self.turn = 0 22 | 23 | self.hand: list[types.HandCard] = [] 24 | self.discard: list[types.Card] = [] 25 | self.draw: list[types.Card] = [] 26 | self.exhaust: list[types.Card] = [] 27 | 28 | self.enemies: list[types.Enemy] = [] 29 | 30 | self.energy = 0 31 | self.block = 0 32 | self.effects: list[types.Effect] = [] 33 | self.orbs: list[types.Orb] = [] 34 | 35 | # TODO make selections part of observation space? 36 | self.hand_selects = [] 37 | self.max_selects = 0 38 | self.can_pick_zero = False 39 | 40 | if "combat_state" in state: 41 | combat_state = state["combat_state"] 42 | 43 | self.turn = combat_state["turn"] 44 | 45 | self.hand = [types.HandCard(**card) for card in combat_state["hand"]] 46 | 47 | self.discard = [types.Card(**card) for card in combat_state["discard_pile"]] 48 | self.discard.sort() 49 | 50 | # TODO what if we have frozen eye? Then the draw order matters. 
51 | self.draw = [types.Card(**card) for card in combat_state["draw_pile"]] 52 | self.draw.sort() 53 | 54 | self.exhaust = [types.Card(**card) for card in combat_state["exhaust_pile"]] 55 | self.exhaust.sort() 56 | 57 | self.enemies = [types.Enemy(**enemy) for enemy in combat_state["monsters"]] 58 | assert len(self.enemies) <= combat_consts.MAX_NUM_ENEMIES 59 | 60 | player_state = combat_state["player"] 61 | self.block = player_state["block"] 62 | self.energy = player_state["energy"] 63 | self.effects = [types.Effect(**effect) for effect in player_state["powers"]] 64 | self.orbs = [types.Orb(**orb) for orb in player_state["orbs"]] 65 | 66 | if state["screen_type"] == base_consts.ScreenType.HAND_SELECT: 67 | screen_state = state["screen_state"] 68 | self.hand_selects = screen_state["selected"] 69 | self.max_selects = screen_state["max_cards"] 70 | self.can_pick_zero = screen_state["can_pick_zero"] 71 | 72 | @staticmethod 73 | def space() -> Dict: 74 | return Dict( 75 | { 76 | "turn": MultiBinary(combat_consts.LOG_MAX_TURN), 77 | "hand": Tuple([types.HandCard.space()] * combat_consts.MAX_HAND_SIZE), 78 | "energy": MultiBinary(combat_consts.LOG_MAX_ENERGY), 79 | "orbs": MultiDiscrete( 80 | [combat_consts.NUM_ORBS] * combat_consts.MAX_ORB_SLOTS 81 | ), 82 | "block": MultiBinary(combat_consts.LOG_MAX_BLOCK), 83 | "effects": spaces.generate_effect_space(), 84 | "enemies": Tuple([types.Enemy.space()] * combat_consts.MAX_NUM_ENEMIES), 85 | "discard": spaces.generate_card_space(), 86 | "draw": spaces.generate_card_space(), 87 | "exhaust": spaces.generate_card_space(), 88 | } 89 | ) 90 | 91 | def serialize(self) -> dict: 92 | turn = utils.to_binary_array(self.turn, combat_consts.LOG_MAX_TURN) 93 | energy = utils.to_binary_array(self.energy, combat_consts.LOG_MAX_ENERGY) 94 | block = utils.to_binary_array(self.block, combat_consts.LOG_MAX_BLOCK) 95 | 96 | hand = [types.HandCard.serialize_empty()] * combat_consts.MAX_HAND_SIZE 97 | for i, card in enumerate(self.hand): 98 | 
card_idx = card.serialize() 99 | hand[i] = card_idx 100 | 101 | effects = types.Effect.serialize_all(self.effects) 102 | orbs = serializers.serialize_orbs(self.orbs) 103 | 104 | enemies = [types.Enemy.serialize_empty()] * combat_consts.MAX_NUM_ENEMIES 105 | for i, enemy in enumerate(self.enemies): 106 | enemies[i] = enemy.serialize() 107 | 108 | discard = serializers.serialize_cards(self.discard) 109 | draw = serializers.serialize_cards(self.draw) 110 | exhaust = serializers.serialize_cards(self.exhaust) 111 | 112 | response = { 113 | "turn": turn, 114 | "hand": hand, 115 | "energy": energy, 116 | "block": block, 117 | "effects": effects, 118 | "orbs": orbs, 119 | "enemies": enemies, 120 | "discard": discard, 121 | "draw": draw, 122 | "exhaust": exhaust, 123 | } 124 | 125 | return response 126 | 127 | class SerializedState(BaseModel): 128 | turn: types.BinaryArray 129 | hand: list[types.HandCard.SerializedState] 130 | energy: types.BinaryArray 131 | block: types.BinaryArray 132 | effects: list[dict] 133 | orbs: npt.NDArray[np.uint] 134 | enemies: list[dict] 135 | discard: npt.NDArray[np.uint] 136 | draw: npt.NDArray[np.uint] 137 | exhaust: npt.NDArray[np.uint] 138 | 139 | class Config: 140 | arbitrary_types_allowed = True 141 | 142 | @classmethod 143 | def deserialize(cls, data: Union[dict, SerializedState]) -> CombatObs: 144 | if not isinstance(data, cls.SerializedState): 145 | data = cls.SerializedState(**data) 146 | 147 | # Instantiate with empty data and update attributes individually, 148 | # rather than trying to recreate CommunicationMod's weird data shape. 
149 | instance = cls({}) 150 | 151 | instance.turn = utils.from_binary_array(data.turn) 152 | instance.energy = utils.from_binary_array(data.energy) 153 | instance.block = utils.from_binary_array(data.block) 154 | 155 | instance.effects = [] 156 | for effect_idx, e in enumerate(data.effects): 157 | effect = types.Effect.deserialize(e) 158 | if effect.amount != 0: 159 | effect.id = combat_consts.ALL_EFFECTS[effect_idx] 160 | instance.effects.append(effect) 161 | 162 | instance.orbs = [] 163 | for o in data.orbs: 164 | orb = types.Orb.deserialize(o) 165 | if orb.id != "NONE": 166 | instance.orbs.append(orb) 167 | 168 | instance.enemies = [] 169 | for e in data.enemies: 170 | enemy = types.Enemy.deserialize(e) 171 | if enemy.id != "NONE": 172 | instance.enemies.append(enemy) 173 | 174 | instance.hand = [] 175 | for hc in data.hand: 176 | hand_card = types.HandCard.deserialize(hc) 177 | if hand_card.id != CardCatalog.NONE.id: 178 | instance.hand.append(hand_card) 179 | 180 | instance.discard = [] 181 | for discard_idx, count in enumerate(data.discard): 182 | discard = types.Card.deserialize(discard_idx) 183 | if discard.id != CardCatalog.NONE.id: 184 | for _ in range(count): 185 | instance.discard.append(discard) 186 | instance.discard.sort() 187 | 188 | instance.draw = [] 189 | for draw_idx, count in enumerate(data.draw): 190 | draw = types.Card.deserialize(draw_idx) 191 | if draw.id != CardCatalog.NONE.id: 192 | for _ in range(count): 193 | instance.draw.append(draw) 194 | instance.draw.sort() 195 | 196 | instance.exhaust = [] 197 | for exhaust_idx, count in enumerate(data.exhaust): 198 | exhaust = types.Card.deserialize(exhaust_idx) 199 | if exhaust.id != CardCatalog.NONE.id: 200 | for _ in range(count): 201 | instance.exhaust.append(exhaust) 202 | instance.exhaust.sort() 203 | 204 | return instance 205 | 206 | def __eq__(self, other: object) -> bool: 207 | if not isinstance(other, CombatObs): 208 | return False 209 | 210 | attrs = [ 211 | "turn", 212 | "hand", 213 
| "energy", 214 | "block", 215 | "effects", 216 | "orbs", 217 | "enemies", 218 | "discard", 219 | "draw", 220 | "exhaust", 221 | ] 222 | 223 | for attr in attrs: 224 | if getattr(self, attr) != getattr(other, attr): 225 | return False 226 | 227 | return True 228 | -------------------------------------------------------------------------------- /gym_sts/build/preferences/STSSeenCards: -------------------------------------------------------------------------------- 1 | { 2 | "Strike_R": "1", 3 | "Defend_R": "1", 4 | "Bash": "1", 5 | "Hemokinesis": "1", 6 | "Clothesline": "1", 7 | "Clash": "1", 8 | "Shrug It Off": "1", 9 | "Iron Wave": "1", 10 | "Reaper": "1", 11 | "Fiend Fire": "1", 12 | "Double Tap": "1", 13 | "Feel No Pain": "1", 14 | "Purity": "1", 15 | "Thinking Ahead": "1", 16 | "Berserk": "1", 17 | "Second Wind": "1", 18 | "Flame Barrier": "1", 19 | "Seeing Red": "1", 20 | "True Grit": "1", 21 | "Rampage": "1", 22 | "Twin Strike": "1", 23 | "Flex": "1", 24 | "Inflame": "1", 25 | "Impatience": "1", 26 | "Magnetism": "1", 27 | "Body Slam": "1", 28 | "Headbutt": "1", 29 | "Infernal Blade": "1", 30 | "Carnage": "1", 31 | "Juggernaut": "1", 32 | "Barricade": "1", 33 | "Wound": "1", 34 | "Fire Breathing": "1", 35 | "Power Through": "1", 36 | "Metallicize": "1", 37 | "Dazed": "1", 38 | "Combust": "1", 39 | "Pommel Strike": "1", 40 | "Warcry": "1", 41 | "Perfected Strike": "1", 42 | "Dual Wield": "1", 43 | "Cleave": "1", 44 | "Thunderclap": "1", 45 | "Rupture": "1", 46 | "Rage": "1", 47 | "Intimidate": "1", 48 | "Entrench": "1", 49 | "Shockwave": "1", 50 | "Armaments": "1", 51 | "Bloodletting": "1", 52 | "Sword Boomerang": "1", 53 | "Ghostly Armor": "1", 54 | "Strike_G": "1", 55 | "Defend_G": "1", 56 | "Survivor": "1", 57 | "Neutralize": "1", 58 | "Endless Agony": "1", 59 | "Underhanded Strike": "1", 60 | "PiercingWail": "1", 61 | "Sucker Punch": "1", 62 | "Backstab": "1", 63 | "Dagger Throw": "1", 64 | "Choke": "1", 65 | "Dodge and Roll": "1", 66 | "Leg Sweep": "1", 
67 | "Blur": "1", 68 | "Crippling Poison": "1", 69 | "Flechettes": "1", 70 | "Quick Slash": "1", 71 | "Deadly Poison": "1", 72 | "Well Laid Plans": "1", 73 | "Madness": "1", 74 | "Slice": "1", 75 | "Outmaneuver": "1", 76 | "Dagger Spray": "1", 77 | "Poisoned Stab": "1", 78 | "Die Die Die": "1", 79 | "Glass Knife": "1", 80 | "Prepared": "1", 81 | "Footwork": "1", 82 | "Finesse": "1", 83 | "Transmutation": "1", 84 | "Burn": "1", 85 | "Bouncing Flask": "1", 86 | "Bullet Time": "1", 87 | "Night Terror": "1", 88 | "Malaise": "1", 89 | "Backflip": "1", 90 | "Blade Dance": "1", 91 | "Tactician": "1", 92 | "Infinite Blades": "1", 93 | "Trip": "1", 94 | "Dash": "1", 95 | "Finisher": "1", 96 | "Masterful Stab": "1", 97 | "Swift Strike": "1", 98 | "Shiv": "1", 99 | "Heel Hook": "1", 100 | "Calculated Gamble": "1", 101 | "Bandage Up": "1", 102 | "Predator": "1", 103 | "Envenom": "1", 104 | "Burst": "1", 105 | "All Out Attack": "1", 106 | "Flying Knee": "1", 107 | "Good Instincts": "1", 108 | "Mayhem": "1", 109 | "Sadistic Nature": "1", 110 | "Forethought": "1", 111 | "Enlightenment": "1", 112 | "Blind": "1", 113 | "HandOfGreed": "1", 114 | "Bane": "1", 115 | "Catalyst": "1", 116 | "Corpse Explosion": "1", 117 | "Strike_B": "1", 118 | "Defend_B": "1", 119 | "Zap": "1", 120 | "Dualcast": "1", 121 | "Hologram": "1", 122 | "Coolheaded": "1", 123 | "Go for the Eyes": "1", 124 | "Loop": "1", 125 | "Compile Driver": "1", 126 | "Leap": "1", 127 | "Consume": "1", 128 | "Sweeping Beam": "1", 129 | "Streamline": "1", 130 | "Reprogram": "1", 131 | "Cold Snap": "1", 132 | "All For One": "1", 133 | "Conserve Battery": "1", 134 | "Defragment": "1", 135 | "Secret Technique": "1", 136 | "Slimed": "1", 137 | "Chill": "1", 138 | "Reboot": "1", 139 | "Aggregate": "1", 140 | "Deflect": "1", 141 | "Distraction": "1", 142 | "Dramatic Entrance": "1", 143 | "Terror": "1", 144 | "Riddle With Holes": "1", 145 | "Acrobatics": "1", 146 | "Unload": "1", 147 | "Wraith Form v2": "1", 148 | "A Thousand Cuts": 
"1", 149 | "Skewer": "1", 150 | "Caltrops": "1", 151 | "BootSequence": "1", 152 | "Dropkick": "1", 153 | "Evaluate": "1", 154 | "EmptyFist": "1", 155 | "Injury": "1", 156 | "Flash of Steel": "1", 157 | "Halt": "1", 158 | "Panacea": "1", 159 | "Secret Weapon": "1", 160 | "Disarm": "1", 161 | "Burning Pact": "1", 162 | "Whirlwind": "1", 163 | "Chrysalis": "1", 164 | "Uppercut": "1", 165 | "Reckless Charge": "1", 166 | "Anger": "1", 167 | "Bludgeon": "1", 168 | "Impervious": "1", 169 | "Corruption": "1", 170 | "Deep Breath": "1", 171 | "Discovery": "1", 172 | "Apotheosis": "1", 173 | "Pummel": "1", 174 | "Demon Form": "1", 175 | "Dark Embrace": "1", 176 | "PanicButton": "1", 177 | "Dark Shackles": "1", 178 | "Battle Trance": "1", 179 | "Feed": "1", 180 | "Violence": "1", 181 | "Heavy Blade": "1", 182 | "Spot Weakness": "1", 183 | "Limit Break": "1", 184 | "Reflex": "1", 185 | "Doppelganger": "1", 186 | "Noxious Fumes": "1", 187 | "Tools of the Trade": "1", 188 | "Necronomicurse": "1", 189 | "Expertise": "1", 190 | "Phantasmal Killer": "1", 191 | "Eviscerate": "1", 192 | "Parasite": "1", 193 | "Adrenaline": "1", 194 | "Venomology": "1", 195 | "After Image": "1", 196 | "CurseOfTheBell": "1", 197 | "The Bomb": "1", 198 | "Metamorphosis": "1", 199 | "Heatsinks": "1", 200 | "Double Energy": "1", 201 | "Storm": "1", 202 | "Mind Blast": "1", 203 | "FTL": "1", 204 | "Reinforced Body": "1", 205 | "Glacier": "1", 206 | "Beam Cell": "1", 207 | "Steam": "1", 208 | "Ball Lightning": "1", 209 | "Doom and Gloom": "1", 210 | "Darkness": "1", 211 | "Steam Power": "1", 212 | "Electrodynamics": "1", 213 | "Multi-Cast": "1", 214 | "Thunder Strike": "1", 215 | "Skim": "1", 216 | "Redo": "1", 217 | "Seek": "1", 218 | "Self Repair": "1", 219 | "Master of Strategy": "1", 220 | "Gash": "1", 221 | "Creative AI": "1", 222 | "Amplify": "1", 223 | "Genetic Algorithm": "1", 224 | "Biased Cognition": "1", 225 | "Barrage": "1", 226 | "Machine Learning": "1", 227 | "Hello World": "1", 228 | "Static 
Discharge": "1",
229 |   "Stack": "1",
230 |   "Chaos": "1",
231 |   "Buffer": "1",
232 |   "Melter": "1",
233 |   "Normality": "1",
234 |   "Capacitor": "1",
235 |   "Rebound": "1",
236 |   "Undo": "1",
237 |   "Echo Form": "1",
238 |   "Strike_P": "1",
239 |   "Defend_P": "1",
240 |   "Eruption": "1",
241 |   "Vigilance": "1",
242 |   "Miracle": "1",
243 |   "WheelKick": "1",
244 |   "FlyingSleeves": "1",
245 |   "CutThroughFate": "1",
246 |   "WreathOfFlame": "1",
247 |   "Adaptation": "1",
248 |   "BattleHymn": "1",
249 |   "EmptyBody": "1",
250 |   "Worship": "1",
251 |   "Smite": "1",
252 |   "WaveOfTheHand": "1",
253 |   "Perseverance": "1",
254 |   "Protect": "1",
255 |   "ConjureBlade": "1",
256 |   "InnerPeace": "1",
257 |   "Swivel": "1",
258 |   "Nirvana": "1",
259 |   "ClearTheMind": "1",
260 |   "MasterReality": "1",
261 |   "DeusExMachina": "1",
262 |   "FollowUp": "1",
263 |   "Sanctity": "1",
264 |   "BowlingBash": "1",
265 |   "ThirdEye": "1",
266 |   "PathToVictory": "1",
267 |   "CrushJoints": "1",
268 |   "WindmillStrike": "1",
269 |   "DevaForm": "1",
270 |   "FlurryOfBlows": "1",
271 |   "Meditate": "1",
272 |   "Judgement": "1",
273 |   "EmptyMind": "1",
274 |   "Indignation": "1",
275 |   "Consecrate": "1",
276 |   "SignatureMove": "1",
277 |   "Crescendo": "1",
278 |   "SashWhip": "1",
279 |   "Wallop": "1",
280 |   "SandsOfTime": "1",
281 |   "Ragnarok": "1",
282 |   "Pray": "1",
283 |   "Insight": "1",
284 |   "Tantrum": "1",
285 |   "Omniscience": "1",
286 |   "Jack Of All Trades": "1",
287 |   "Panache": "1",
288 |   "TalkToTheHand": "1",
289 |   "Conclude": "1",
290 |   "FearNoEvil": "1",
291 |   "LikeWater": "1",
292 |   "Void": "1",
293 |   "Prostrate": "1",
294 |   "Blasphemy": "1",
295 |   "Devotion": "1",
296 |   "Escape Plan": "1",
297 |   "Ghostly": "1",
298 |   "Cloak And Dagger": "1",
299 |   "Accuracy": "1",
300 |   "Storm of Steel": "1",
301 |   "Grand Finale": "1",
302 |   "Sever Soul": "1",
303 |   "Searing Blow": "1",
304 |   "Offering": "1",
305 |   "Fission": "1",
306 |   "Blizzard": "1",
307 |   "Scrape": "1",
308 |   "Auto Shields": "1",
309 |   "Turbo": "1",
310 |   "Sunder": "1",
311 |   "Meteor Strike": "1",
312 |   "Force Field": "1",
313 |   "Lockon": "1",
314 |   "White Noise": "1",
315 |   "Fusion": "1",
316 |   "Rip and Tear": "1",
317 |   "Rainbow": "1",
318 |   "Hyperbeam": "1",
319 |   "Recycle": "1",
320 |   "Core Surge": "1",
321 |   "Havoc": "1",
322 |   "Blood for Blood": "1",
323 |   "Pain": "1",
324 |   "Immolate": "1",
325 |   "Wild Strike": "1",
326 |   "Writhe": "1",
327 |   "Evolve": "1",
328 |   "Brutality": "1",
329 |   "Sentinel": "1",
330 |   "Exhume": "1",
331 |   "Regret": "1",
332 |   "Doubt": "1",
333 |   "Decay": "1",
334 |   "Tempest": "1",
335 |   "Brilliance": "1",
336 |   "Weave": "1",
337 |   "Establishment": "1",
338 |   "CarveReality": "1",
339 |   "Scrawl": "1",
340 |   "JustLucky": "1",
341 |   "Collect": "1",
342 |   "Study": "1",
343 |   "ForeignInfluence": "1",
344 |   "Alpha": "1",
345 |   "MentalFortress": "1",
346 |   "Shame": "1",
347 |   "ReachHeaven": "1",
348 |   "DeceiveReality": "1",
349 |   "ThroughViolence": "1",
350 |   "RitualDagger": "1",
351 |   "Fasting2": "1",
352 |   "SpiritShield": "1",
353 |   "Wish": "1",
354 |   "Wireheading": "1"
355 | }
356 |
--------------------------------------------------------------------------------
/gym_sts/spaces/observations/types/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union
4 |
5 | import numpy as np
6 | import numpy.typing as npt
7 | from gymnasium.spaces import Dict, Discrete, MultiBinary
8 | from pydantic import BaseModel, Field, validator
9 |
10 | import gym_sts.spaces.constants.base as base_consts
11 | import gym_sts.spaces.constants.combat as combat_consts
12 | import gym_sts.spaces.constants.shop as shop_consts
13 | from gym_sts.spaces.observations import spaces, utils
14 |
15 |
16 | BinaryArray = npt.NDArray[np.uint]
17 |
18 |
19 | class ShopMixin(BaseModel):
20 |     price: int = Field(..., ge=0, lt=2**shop_consts.SHOP_LOG_MAX_PRICE)
21 |
22 |     @staticmethod
23 |     def serialize_empty_price():
24 |         return utils.to_binary_array(0, shop_consts.SHOP_LOG_MAX_PRICE)
25 |
26 |     def serialize_price(self):
27 |         return utils.to_binary_array(self.price, shop_consts.SHOP_LOG_MAX_PRICE)
28 |
29 |     @staticmethod
30 |     def deserialize_price(price: BinaryArray) -> int:
31 |         return utils.from_binary_array(price)
32 |
33 |
34 | class Keys(BaseModel):
35 |     emerald: bool = False
36 |     ruby: bool = False
37 |     sapphire: bool = False
38 |
39 |     def serialize(self) -> BinaryArray:
40 |         _keys = [self.ruby, self.emerald, self.sapphire]
41 |         return np.array([int(key) for key in _keys])
42 |
43 |     @classmethod
44 |     def deserialize(cls, data: BinaryArray) -> Keys:
45 |         ruby, emerald, sapphire = data
46 |         return cls(ruby=ruby, emerald=emerald, sapphire=sapphire)
47 |
48 |
49 | class Effect(BaseModel):
50 |     id: str = "EMPTY"  # Placeholder should always be replaced
51 |     amount: int = Field(..., ge=-combat_consts.MAX_EFFECT, le=combat_consts.MAX_EFFECT)
52 |
53 |     def serialize(self) -> dict:
54 |         sign = 0
55 |         value = self.amount
56 |
57 |         if self.amount < 0:
58 |             sign = 1
59 |             value = -value
60 |
61 |         return {
62 |             "sign": sign,
63 |             "value": utils.to_binary_array(value, combat_consts.LOG_MAX_EFFECT),
64 |         }
65 |
66 |     @staticmethod
67 |     def serialize_all(effects: list[Effect]) -> list[dict]:
68 |         serialized = []
69 |         effect_map = {effect.id: effect for effect in effects}
70 |
71 |         for effect_id in combat_consts.ALL_EFFECTS:
72 |             encoding = {
73 |                 "sign": 0,
74 |                 "value": utils.to_binary_array(0, combat_consts.LOG_MAX_EFFECT),
75 |             }
76 |             if effect_id in effect_map:
77 |                 effect = effect_map[effect_id]
78 |                 encoding = effect.serialize()
79 |
80 |             serialized.append(encoding)
81 |
82 |         return serialized
83 |
84 |     class SerializedState(BaseModel):
85 |         sign: int = Field(..., ge=0, le=1)
86 |         value: BinaryArray
87 |
88 |         class Config:
89 |             arbitrary_types_allowed = True
90 |
91 |     @classmethod
92 |     def deserialize(cls, data: Union[dict, SerializedState]) -> Effect:
93 |         if not isinstance(data, cls.SerializedState):
94 |             data = cls.SerializedState(**data)
95 |
96 |         amount = utils.from_binary_array(data.value)
97 |         if data.sign:
98 |             amount = -amount
99 |
100 |         return cls(amount=amount)
101 |
102 |
103 | class Orb(BaseModel):
104 |     id: str = "Empty"  # STS seems to have a bug where empty orbs sometimes have no ID
105 |
106 |     @staticmethod
107 |     def serialize_empty() -> int:
108 |         return combat_consts.ALL_ORBS.index("NONE")
109 |
110 |     def serialize(self) -> int:
111 |         return combat_consts.ALL_ORBS.index(self.id)
112 |
113 |     @classmethod
114 |     def deserialize(cls, orb_idx: int) -> Orb:
115 |         orb_id = combat_consts.ALL_ORBS[orb_idx]
116 |         return cls(id=orb_id)
117 |
118 |
119 | class Health(BaseModel):
120 |     hp: int = Field(..., ge=0, le=base_consts.MAX_HP)
121 |     max_hp: int = Field(..., ge=0, le=base_consts.MAX_HP)
122 |
123 |     def serialize(self) -> dict:
124 |         return {
125 |             "hp": utils.to_binary_array(self.hp, base_consts.LOG_MAX_HP),
126 |             "max_hp": utils.to_binary_array(self.max_hp, base_consts.LOG_MAX_HP),
127 |         }
128 |
129 |     class SerializedState(BaseModel):
130 |         hp: BinaryArray
131 |         max_hp: BinaryArray
132 |
133 |         class Config:
134 |             arbitrary_types_allowed = True
135 |
136 |     @classmethod
137 |     def deserialize(cls, data: Union[dict, SerializedState]) -> Health:
138 |         if not isinstance(data, cls.SerializedState):
139 |             data = cls.SerializedState(**data)
140 |
141 |         hp = utils.from_binary_array(data.hp)
142 |         max_hp = utils.from_binary_array(data.max_hp)
143 |
144 |         return cls(hp=hp, max_hp=max_hp)
145 |
146 |
147 | class Attack(BaseModel):
148 |     damage: int = Field(..., ge=0, le=combat_consts.MAX_ATTACK)
149 |     times: int = Field(..., ge=0, lt=combat_consts.MAX_ATTACK_TIMES)
150 |
151 |     @staticmethod
152 |     def space() -> Dict:
153 |         return Dict(
154 |             {
155 |                 "damage": MultiBinary(combat_consts.LOG_MAX_ATTACK),
156 |                 "times": MultiBinary(combat_consts.LOG_MAX_ATTACK_TIMES),
157 |             }
158 |         )
159 |
160 |     def serialize(self):
161 |         return {
162 |             "damage": utils.to_binary_array(self.damage, combat_consts.LOG_MAX_ATTACK),
163 |             "times": utils.to_binary_array(
164 |                 self.times, combat_consts.LOG_MAX_ATTACK_TIMES
165 |             ),
166 |         }
167 |
168 |     class SerializedState(BaseModel):
169 |         damage: BinaryArray
170 |         times: BinaryArray
171 |
172 |         class Config:
173 |             arbitrary_types_allowed = True
174 |
175 |     @classmethod
176 |     def deserialize(cls, data: Union[dict, SerializedState]) -> Attack:
177 |         if not isinstance(data, cls.SerializedState):
178 |             data = cls.SerializedState(**data)
179 |
180 |         return cls(
181 |             damage=utils.from_binary_array(data.damage),
182 |             times=utils.from_binary_array(data.times),
183 |         )
184 |
185 |
186 | class Enemy(BaseModel):
187 |     id: str
188 |     intent: str
189 |     current_hp: int = Field(..., ge=0, le=base_consts.MAX_HP)
190 |     max_hp: int = Field(..., ge=0, le=base_consts.MAX_HP)
191 |     block: int = Field(..., ge=0, le=combat_consts.MAX_BLOCK)
192 |     effects: list[Effect] = Field([], alias="powers")
193 |
194 |     # These attributes may not be set if the player has Runic Dome
195 |     damage: int = Field(
196 |         0, alias="move_adjusted_damage", ge=0, le=combat_consts.MAX_ATTACK
197 |     )
198 |     times: int = Field(0, alias="move_hits", ge=0, le=combat_consts.MAX_ATTACK_TIMES)
199 |
200 |     @validator("damage", pre=True)
201 |     def must_be_nonnegative(cls, v: int) -> int:
202 |         return max(0, v)
203 |
204 |     @staticmethod
205 |     def space() -> Dict:
206 |         return Dict(
207 |             {
208 |                 "id": Discrete(combat_consts.NUM_MONSTER_TYPES),
209 |                 "intent": Discrete(combat_consts.NUM_INTENTS),
210 |                 "attack": Attack.space(),
211 |                 "block": MultiBinary(combat_consts.LOG_MAX_BLOCK),
212 |                 "effects": spaces.generate_effect_space(),
213 |                 "health": spaces.generate_health_space(),
214 |             }
215 |         )
216 |
217 |     @staticmethod
218 |     def serialize_empty() -> dict:
219 |         serialized = {
220 |             "id": 0,
221 |             "intent": 0,
222 |             "attack": Attack(damage=0, times=0).serialize(),
223 |             "block": utils.to_binary_array(0, combat_consts.LOG_MAX_BLOCK),
224 |             "effects": Effect.serialize_all([]),
225 |             "health": Health(hp=0, max_hp=0).serialize(),
226 |         }
227 |
228 |         return serialized
229 |
230 |     def serialize(self) -> dict:
231 |         serialized = {
232 |             "id": combat_consts.ALL_MONSTER_TYPES.index(self.id),
233 |             "intent": combat_consts.ALL_INTENTS.index(self.intent),
234 |             "attack": Attack(damage=self.damage, times=self.times).serialize(),
235 |             "block": utils.to_binary_array(self.block, combat_consts.LOG_MAX_BLOCK),
236 |             "effects": Effect.serialize_all(self.effects),
237 |             "health": Health(hp=self.current_hp, max_hp=self.max_hp).serialize(),
238 |         }
239 |
240 |         return serialized
241 |
242 |     class SerializedState(BaseModel):
243 |         id: int
244 |         intent: int
245 |         attack: Attack.SerializedState
246 |         block: BinaryArray
247 |         effects: list[dict]
248 |         health: Health.SerializedState
249 |
250 |         class Config:
251 |             arbitrary_types_allowed = True
252 |
253 |     @classmethod
254 |     def deserialize(cls, data: Union[dict, SerializedState]) -> Enemy:
255 |         if not isinstance(data, cls.SerializedState):
256 |             data = cls.SerializedState(**data)
257 |
258 |         effects = []
259 |         for effect_idx, e in enumerate(data.effects):
260 |             effect = Effect.deserialize(e)
261 |             if effect.amount != 0:
262 |                 effect.id = combat_consts.ALL_EFFECTS[effect_idx]
263 |                 effects.append(effect)
264 |
265 |         health = Health.deserialize(data.health)
266 |         attack = Attack.deserialize(data.attack)
267 |
268 |         return cls(
269 |             id=combat_consts.ALL_MONSTER_TYPES[data.id],
270 |             intent=combat_consts.ALL_INTENTS[data.intent],
271 |             block=utils.from_binary_array(data.block),
272 |             powers=effects,
273 |             current_hp=health.hp,
274 |             max_hp=health.max_hp,
275 |             move_adjusted_damage=attack.damage,
276 |             move_hits=attack.times,
277 |         )
278 |
--------------------------------------------------------------------------------
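Several models in `base.py` (`ShopMixin`, `Effect`, `Health`, `Attack`, and the `block` field of `Enemy`) encode bounded integers as fixed-width binary arrays through `utils.to_binary_array` and `utils.from_binary_array`, and `Effect` adds a separate sign bit for negative amounts. The sketch below illustrates that sign-magnitude round trip in isolation. It is not the package's code: `to_binary_array`, `from_binary_array`, `encode_effect`, and `decode_effect` are hypothetical stand-ins, because the `utils` module is not shown here, and the most-significant-bit-first ordering is an assumption.

```python
import numpy as np


# Hypothetical stand-in for utils.to_binary_array (not shown in this dump).
# Assumes most-significant-bit-first ordering; n must satisfy 0 <= n < 2**digits.
def to_binary_array(n: int, digits: int) -> np.ndarray:
    if n < 0 or n >= 2**digits:
        raise ValueError(f"{n} does not fit in {digits} unsigned bits")
    return np.array([(n >> i) & 1 for i in reversed(range(digits))], dtype=np.uint8)


# Hypothetical stand-in for utils.from_binary_array: inverse of the function above.
def from_binary_array(bits) -> int:
    value = 0
    for bit in bits:
        value = (value << 1) | int(bit)
    return value


def encode_effect(amount: int, log_max: int) -> dict:
    # Sign-magnitude, in the spirit of Effect.serialize: one sign bit plus
    # the absolute value as a fixed-width binary array.
    sign = 1 if amount < 0 else 0
    return {"sign": sign, "value": to_binary_array(abs(amount), log_max)}


def decode_effect(encoded: dict) -> int:
    # Mirrors Effect.deserialize: read the magnitude, then apply the sign.
    amount = from_binary_array(encoded["value"])
    return -amount if encoded["sign"] else amount


encoded = encode_effect(-5, 4)
print(encoded["sign"], encoded["value"].tolist())  # 1 [0, 1, 0, 1]
print(decode_effect(encoded))  # -5
```

One consequence visible in `Effect.serialize_all` and `Enemy.deserialize`: every effect in `combat_consts.ALL_EFFECTS` gets a slot, absent effects are encoded as zero, and `Enemy.deserialize` keeps only slots with a nonzero amount, so an effect whose amount is 0 cannot be distinguished from a missing one.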