├── data ├── Round_1.pt └── README.rst ├── doc └── static │ ├── workflow.png │ ├── fight_graph.png │ ├── fight_graph_full.png │ └── battle_graph_full.png ├── sapai ├── __init__.py ├── rand.py ├── model.py ├── play.py ├── status.py ├── tiers.py ├── compress.py ├── foods.py ├── graph.py ├── teams.py ├── lists.py ├── player.py ├── agents.py └── battle.py ├── setup.py ├── tests ├── test_teams.py ├── test_readme_code.py ├── test_agents.py ├── test_status.py ├── test_state.py ├── test_seeds.py ├── test_player.py ├── test_effects.py ├── test_shop.py ├── test_lists.py └── test_battles.py ├── .github └── workflows │ ├── lint.yaml │ ├── run_tests.yml │ └── build.yaml ├── LICENSE ├── .gitignore └── README.rst /data/Round_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manny405/sapai/HEAD/data/Round_1.pt -------------------------------------------------------------------------------- /doc/static/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manny405/sapai/HEAD/doc/static/workflow.png -------------------------------------------------------------------------------- /doc/static/fight_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manny405/sapai/HEAD/doc/static/fight_graph.png -------------------------------------------------------------------------------- /doc/static/fight_graph_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manny405/sapai/HEAD/doc/static/fight_graph_full.png -------------------------------------------------------------------------------- /doc/static/battle_graph_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manny405/sapai/HEAD/doc/static/battle_graph_full.png 
-------------------------------------------------------------------------------- /sapai/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import data 2 | from .lists import Slot, SAPList 3 | from .pets import Pet 4 | from .foods import Food 5 | from .teams import Team, TeamSlot 6 | from .battle import Battle 7 | from .shop import Shop, ShopSlot 8 | from .player import Player 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from setuptools import find_packages 4 | from distutils.core import setup 5 | 6 | setup( 7 | name="sapai", 8 | version="0.1.0", 9 | packages=[ 10 | "sapai", 11 | ], 12 | # find_packages(exclude=[]), 13 | install_requires=["numpy", "torch", "graphviz"], 14 | data_files=[], 15 | ) 16 | -------------------------------------------------------------------------------- /tests/test_teams.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | import numpy as np 4 | 5 | from sapai import * 6 | from sapai.compress import compress, decompress 7 | 8 | MIN = -5 9 | MAX = 20 10 | pet_names = list(data["pets"].keys()) 11 | 12 | 13 | class TestLists(unittest.TestCase): 14 | def test_team(self): 15 | l = Team([Pet(pet_names[x]) for x in range(3)]) 16 | 17 | 18 | # %% 19 | 20 | # test = TestLists() 21 | # test.test_remove() 22 | 23 | # %% 24 | -------------------------------------------------------------------------------- /tests/test_readme_code.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sapai import * 4 | 5 | 6 | class TestReadMeCode(unittest.TestCase): 7 | def test_pet_creation(self): 8 | pet = Pet("ant") 9 | pet._attack += 3 10 | pet.gain_experience() 11 | 12 | def test_team_move(self): 13 | team0 = 
Team(["ant", "ox", "tiger"]) 14 | team1 = Team(["sheep", "tiger"]) 15 | team0.move(1, 4) 16 | team0.move_forward() 17 | 18 | def test_running_battle(self): 19 | team0 = Team(["ant", "ox", "tiger"]) 20 | team1 = Team(["sheep", "tiger"]) 21 | battle = Battle(team0, team1) 22 | winner = battle.battle() 23 | print(winner) 24 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | # Trigger the workflow on push or pull request, 5 | push: 6 | pull_request: 7 | 8 | jobs: 9 | run-linters: 10 | name: Run linters 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Check out Git repository 15 | uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Install Python dependencies 23 | run: pip install black==22.6.0 24 | 25 | - name: Run linters 26 | uses: wearerequired/lint-action@v2 27 | with: 28 | black: true 29 | black_args: "sapai/*.py tests/*.py --check" 30 | commit: false 31 | continue_on_error: false 32 | -------------------------------------------------------------------------------- /sapai/rand.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MockRandomState: 5 | """ 6 | Numpy RandomState is actually extremely slow, requiring about 300 microseconds 7 | for any operation involving state. Therefore, when reproducibility is not 8 | necessary, this class should be used to immensly improve efficiency. 9 | 10 | Tests were run for Player: 11 | %timeit _ = Player.from_state(pstate) 12 | # ORIGINAL: 26.8 ms ± 788 µs per loop 13 | # MockRandomState: 1.15 ms ± 107 µs per loop 14 | Use of MockRandomState improved the performance by 10x. This is very important. 
15 | 16 | """ 17 | 18 | def __init__(self): 19 | pass 20 | 21 | def set_state(self): 22 | """Doesn't do anything""" 23 | return None 24 | 25 | def get_state(self): 26 | return None 27 | 28 | def choice(self, *args, **kwargs): 29 | return np.random.choice(*args, **kwargs) 30 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-linux: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | max-parallel: 5 15 | defaults: 16 | run: 17 | shell: bash -l {0} 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python 3.11.1 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.11.1 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install --user torch --extra-index-url https://download.pytorch.org/whl/cu117/ 28 | pip install flake8 pytest 29 | python setup.py install 30 | - name: Lint with flake8 31 | run: | 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | - name: Test with pytest 35 | run: | 36 | pytest tests 37 | -------------------------------------------------------------------------------- /sapai/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Definition and implementation of the reinforment model that will be used to 4 | expertly learn how to play the game. 5 | 6 | Notes: 7 | - The capability to add rules to the model should be added. For example, 8 | only allowing the model to purchase certain animals. This will allow 9 | the capability to learn novel and more interesting strategies. 
10 | - Beginning training model only with tier1 animals for the first few turns 11 | thereby making it easier to learn good initial behavior 12 | - Continue with tier2 animals and so on 13 | 14 | - Model makes selections based on a mask of possible behaviors that are 15 | available. For example, if there's no gold left, then the purchase 16 | and roll actions have a mask value of 0 and therefore the action 17 | cannot be performed. This will hopefully not be detremental to training. 18 | - In addition, biasing the training such that the desired value for these 19 | is zero may be a better approach. 20 | 21 | """ 22 | -------------------------------------------------------------------------------- /sapai/play.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | 4 | class Play: 5 | """ 6 | Input the player list and mode that the game should be played in. This class 7 | Defines the actions of an overall game including shopping, battles, and 8 | lives of the player. 9 | 10 | Tournament mode will be the most basic mode for learning purposes. This just 11 | constructs a tournament of players where each turn the losing is removed 12 | from the tournament. This type of game will optimize only the most ideal 13 | play. 14 | 15 | Arena mode will mimic the way the game seems to be played. In this, after 16 | each match, players will be pooled based on thier record. Only players with 17 | similar records may play one another. Once a player reaches 0 lives, their 18 | play will end. 19 | 20 | Versus mode will mimic the versus game-mode including the way that the game 21 | adds clones of characters to the game. 
22 | 23 | """ 24 | 25 | def __init__(self, players=None, mode="tournament"): 26 | players = players or [] 27 | 28 | raise NotImplementedError 29 | 30 | 31 | # %% 32 | -------------------------------------------------------------------------------- /tests/test_agents.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sapai.shop import * 4 | from sapai.agents import * 5 | 6 | 7 | class TestAgents(unittest.TestCase): 8 | def test_CombinatorialSearch(self): 9 | turn = 1 10 | player = Player( 11 | team=["ant", "fish", "beaver", "cricket", "horse"], 12 | shop=ShopLearn(turn=turn), 13 | turn=turn, 14 | ) 15 | cs = CombinatorialSearch() 16 | avail_actions = cs.avail_actions(player) 17 | 18 | for temp_action in avail_actions: 19 | if len(temp_action) == 0: 20 | temp_name = "None" 21 | else: 22 | temp_name = temp_action[0].__name__ 23 | 24 | if len(temp_action) > 1: 25 | temp_inputs = temp_action[1:] 26 | else: 27 | temp_inputs = [] 28 | 29 | def test_simple_CombinatorialSearch(self): 30 | turn = 1 31 | player = Player(shop=ShopLearn(turn=turn), turn=turn) 32 | player.gold = 10 33 | cs = CombinatorialSearch(max_actions=3) 34 | player_list, team_dict = cs.search(player) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Manny Bier 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - "*" 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-20.04 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.6 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.6 18 | 19 | - name: Install dependencies 20 | run: pip install wheel setuptools 21 | 22 | - name: Build wheel 23 | run: python setup.py bdist_wheel 24 | 25 | - name: Upload Python wheel 26 | uses: actions/upload-artifact@v2 27 | with: 28 | name: Python wheel 29 | path: ${{github.workspace}}/dist/sapai-*.whl 30 | if-no-files-found: error 31 | 32 | test: 33 | needs: build 34 | runs-on: ${{matrix.os}} 35 | strategy: 36 | max-parallel: 10 37 | matrix: 38 | python-version: [3.6] 39 | os: [ubuntu-20.04, windows-2019, macos-11] 40 | 41 | steps: 42 | - uses: actions/checkout@v2 43 | - name: Set up Python ${{matrix.python-version}} 44 | uses: actions/setup-python@v2 45 | with: 46 | python-version: ${{matrix.python-version}} 47 | 48 | - name: Download artifact 49 | uses: actions/download-artifact@master 50 | with: 51 | name: "Python wheel" 52 | 53 | - name: Install wheel 54 | run: pip install 
--find-links=${{github.workspace}} sapai 55 | 56 | - name: Test library accessibility 57 | run: python -c "import sapai" 58 | -------------------------------------------------------------------------------- /data/README.rst: -------------------------------------------------------------------------------- 1 | ==== 2 | Data 3 | ==== 4 | 5 | 6 | This folder contains data that may be helpful for training agents. 7 | 8 | ------- 9 | Round 1 10 | ------- 11 | 12 | The ``Round_1.pt`` file contains a dictionary of all possible teams that can be constructed in round 1 of SAP and their relative win rate against one another. This file can be loaded into memory using the following code. 13 | 14 | .. code-block:: python 15 | 16 | >>> import torch 17 | >>> results = torch.load("Round_1.pt") 18 | >>> print(f"Number of Possible Teams: {len(results)}") 19 | Number of Possible Teams: 5013 20 | 21 | Each key of the dictionary is a lossless compression of a Team. The Team can be rebuilt in memory easily using ``sapai.compress.decompress``. For example, if you would like to examine the best possible teams from Round 1, the following code can be used. 22 | 23 | ..
code-block:: python 24 | 25 | >>> import torch 26 | >>> import numpy as np 27 | >>> from sapai.compress import decompress 28 | >>> results = torch.load("Round_1.pt") 29 | >>> keys = np.array(list(results.keys())) 30 | >>> win_rate = np.array(list(results.values())) 31 | >>> sort_idx = np.argsort(win_rate)[::-1] 32 | >>> best_win_rate = win_rate[sort_idx[0]] 33 | >>> best_team = decompress(keys[sort_idx[0]]) 34 | >>> print(f"BEST WIN RATE: {best_win_rate:.3f}") 35 | 0.910 36 | >>> print(f"BEST TEAM: \n{best_team}") 37 | BEST TEAM: 38 | 0: < Slot pet-fish 2-3 none 1-0 > 39 | 1: < Slot pet-ant 2-1 none 1-0 > 40 | 2: < Slot pet-cricket 1-2 none 1-0 > 41 | 3: < Slot EMPTY > 42 | 4: < Slot EMPTY > 43 | -------------------------------------------------------------------------------- /tests/test_status.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | import numpy as np 4 | 5 | from sapai import * 6 | from sapai.compress import compress, decompress 7 | from sapai.status import * 8 | 9 | # %% 10 | 11 | MIN = -10 12 | MAX = 100 13 | 14 | 15 | class TestStatus(unittest.TestCase): 16 | def test_damage(self): 17 | p = Pet("fish") 18 | for i in range(MIN, 0): 19 | self.assertEqual(p.get_damage(i), 0) 20 | for i in range(0, MAX): 21 | self.assertEqual(p.get_damage(i), i) 22 | 23 | def test_garlic_damage(self): 24 | p = Pet("fish") 25 | p.eat(Food("garlic")) 26 | for i in range(MIN, 0): 27 | self.assertEqual(p.get_damage(i), 0) 28 | for i in range(1, 3): 29 | self.assertEqual(p.get_damage(i), 1) 30 | for i in range(3, MAX): 31 | self.assertEqual(p.get_damage(i), (i - 2)) 32 | 33 | def test_melon_damage(self): 34 | p = Pet("fish") 35 | p.eat(Food("melon")) 36 | for i in range(MIN, 20): 37 | self.assertEqual(p.get_damage(i), 0) 38 | for i in range(21, MAX): 39 | self.assertEqual(p.get_damage(i), (i - 20)) 40 | 41 | def test_coconut_damage(self): 42 | p = Pet("fish") 43 | p.status = "status-coconut-shield" 44 | for i 
in range(MIN, MAX): 45 | self.assertEqual(p.get_damage(i), 0) 46 | 47 | def test_weak(self): 48 | p = Pet("fish") 49 | p.status = "status-weak" 50 | for i in range(MIN, 1): 51 | self.assertEqual(p.get_damage(i), 0) 52 | for i in range(1, MAX): 53 | self.assertEqual(p.get_damage(i), i + 3) 54 | 55 | 56 | # %% 57 | -------------------------------------------------------------------------------- /tests/test_state.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sapai import * 4 | from sapai.shop import * 5 | from sapai.compress import * 6 | 7 | 8 | # TODO : In most tests, assert states are equal once `__eq__` is implemented for classes 9 | class TestState(unittest.TestCase): 10 | def test_pet_level_state(self): 11 | expected_level = 3 12 | p = Pet("ant") 13 | p.level = expected_level 14 | pet_from_state = Pet.from_state(p.state) 15 | self.assertEqual(pet_from_state.level, expected_level) 16 | 17 | def test_food_from_state(self): 18 | expected_food = Food("melon") 19 | actual_food = Food.from_state(expected_food.state) 20 | 21 | def test_team_from_state(self): 22 | expected_team = Team([Pet("fish"), Pet("dragon"), Pet("cat")]) 23 | expected_level = 3 24 | expected_team[0].pet.level = expected_level 25 | actual_team = Team.from_state(expected_team.state) 26 | self.assertEqual(actual_team[0].pet.level, expected_level) 27 | 28 | def test_shop_slot_state(self): 29 | actual = ShopSlot("pet") 30 | actual.roll() 31 | expected = ShopSlot.from_state(actual.state) 32 | 33 | def test_shop_state(self): 34 | actual = Shop() 35 | expected = Shop.from_state(actual.state) 36 | 37 | def test_player_state(self): 38 | expected = Player(team=Team(["ant", "fish", "dragon"])) 39 | actual = Player.from_state(expected.state) 40 | 41 | def test_compress_shop(self): 42 | expected = Shop() 43 | compressed = compress(expected, minimal=True) 44 | actual = decompress(compressed) 45 | 46 | def test_shop_state_equality(self): 47 | 
expected = Shop() 48 | actual = Shop.from_state(expected.state) 49 | self.assertEqual(expected.state, actual.state) 50 | 51 | def test_compress_player(self): 52 | expected = Player() 53 | compressed = compress(expected, minimal=True) 54 | actual = decompress(compressed) 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # File system files 2 | .DS_Store 3 | sapai/**/.DS_Store 4 | .vscode 5 | sapai/**/.vscode 6 | .idea/ 7 | 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /sapai/status.py: -------------------------------------------------------------------------------- 1 | def apply_(value): 2 | return max([value, 0]) 3 | 4 | 5 | def apply_garlic_armor(value): 6 | if value > 0: 7 | return max([value - 2, 1]) 8 | else: 9 | return 0 10 | 11 | 12 | def apply_melon_armor(value): 13 | return max([value - 20, 0]) 14 | 15 | 16 | def apply_coconut_shield(value): 17 | return 0 18 | 19 | 20 | def apply_bone_attack(value): 21 | if value > 0: 22 | return value + 4 23 | else: 24 | return 0 25 | 26 | 27 | def apply_steak_attack(value): 28 | if value > 0: 29 | return value + 20 30 | else: 31 | return 0 32 | 33 | 34 | def apply_weak(value): 35 | if value > 0: 36 | return value + 3 37 | else: 38 | return 0 39 | 40 | 41 | def apply_poison_attack(value): 42 | if value > 0: 43 | return 1000 44 | else: 45 | return 0 46 | 47 | 48 | def apply_splash_attack(value): 49 | return apply_(value) 50 | 51 | 52 | def apply_honey_bee(pet, team): 53 | raise NotImplementedError 54 | 55 | 56 | def apply_extra_life(pet, team): 57 | raise NotImplementedError 58 | 59 | 60 | apply_null_dict = { 61 | "none": apply_, 62 | "status-bone-attack": apply_, 63 | "status-coconut-shield": apply_, 64 | "status-extra-life": apply_, 65 | "status-garlic-armor": apply_, 66 | "status-honey-bee": apply_, 67 | 
"status-melon-armor": apply_, 68 | "status-poison-attack": apply_, 69 | "status-splash-attack": apply_, 70 | "status-steak-attack": apply_, 71 | "status-weak": apply_, 72 | } 73 | 74 | apply_damage_dict = { 75 | "none": apply_, 76 | "status-bone-attack": apply_, 77 | "status-coconut-shield": apply_coconut_shield, 78 | "status-extra-life": apply_, 79 | "status-garlic-armor": apply_garlic_armor, 80 | "status-honey-bee": apply_, 81 | "status-melon-armor": apply_melon_armor, 82 | "status-poison-attack": apply_, 83 | "status-splash-attack": apply_, 84 | "status-steak-attack": apply_, 85 | "status-weak": apply_weak, 86 | } 87 | 88 | apply_attack_dict = { 89 | "none": apply_, 90 | "status-bone-attack": apply_bone_attack, 91 | "status-coconut-shield": apply_, 92 | "status-extra-life": apply_, 93 | "status-garlic-armor": apply_, 94 | "status-honey-bee": apply_, 95 | "status-melon-armor": apply_, 96 | "status-poison-attack": apply_poison_attack, 97 | "status-splash-attack": apply_splash_attack, 98 | "status-steak-attack": apply_steak_attack, 99 | "status-weak": apply_, 100 | } 101 | 102 | apply_faint_dict = { 103 | "none": apply_, 104 | "status-bone-attack": apply_, 105 | "status-coconut-shield": apply_, 106 | "status-extra-life": apply_extra_life, 107 | "status-garlic-armor": apply_, 108 | "status-honey-bee": apply_honey_bee, 109 | "status-melon-armor": apply_, 110 | "status-poison-attack": apply_, 111 | "status-splash-attack": apply_, 112 | "status-steak-attack": apply_, 113 | "status-weak": apply_, 114 | } 115 | 116 | apply_once = { 117 | "status-coconut-shield", 118 | "status-melon-armor", 119 | "status-steak-attack", 120 | } 121 | -------------------------------------------------------------------------------- /sapai/tiers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sapai.data import data 3 | 4 | 5 | pet_tier_lookup = {1: [], 2: [], 3: [], 4: [], 5: [], 6: []} 6 | pet_tier_lookup_std = {1: [], 2: 
[], 3: [], 4: [], 5: [], 6: []} 7 | for key, value in data["pets"].items(): 8 | if type(value["tier"]) == int: 9 | pet_tier_lookup[value["tier"]].append(key) 10 | if "StandardPack" in value["packs"]: 11 | pet_tier_lookup_std[value["tier"]].append(key) 12 | pet_tier_avail_lookup = {1: [], 2: [], 3: [], 4: [], 5: [], 6: []} 13 | for key, value in pet_tier_lookup.items(): 14 | for temp_key, temp_value in pet_tier_avail_lookup.items(): 15 | if temp_key >= key: 16 | temp_value += value 17 | 18 | food_tier_lookup = {1: [], 2: [], 3: [], 4: [], 5: [], 6: []} 19 | for key, value in data["foods"].items(): 20 | if type(value["tier"]) == int: 21 | food_tier_lookup[value["tier"]].append(key) 22 | food_tier_avail_lookup = {1: [], 2: [], 3: [], 4: [], 5: [], 6: []} 23 | for key, value in food_tier_lookup.items(): 24 | for temp_key, temp_value in food_tier_avail_lookup.items(): 25 | if temp_key >= key: 26 | temp_value += value 27 | 28 | turn_prob_pets_std = {} 29 | turn_prob_pets_exp = {} 30 | for i in np.arange(0, 12): 31 | turn_prob_pets_std[i] = {} 32 | turn_prob_pets_exp[i] = {} 33 | for key, value in data["pets"].items(): 34 | if "probabilities" not in value: 35 | continue 36 | if data["pets"][key]["probabilities"] == "none": 37 | continue 38 | for temp_dict in data["pets"][key]["probabilities"]: 39 | temp_turn = int(temp_dict["turn"].split("-")[-1]) 40 | if "StandardPack" in temp_dict["perSlot"]: 41 | temp_std = temp_dict["perSlot"]["StandardPack"] 42 | turn_prob_pets_std[temp_turn][key] = temp_std 43 | if "ExpansionPack1" in temp_dict["perSlot"]: 44 | temp_exp = temp_dict["perSlot"]["ExpansionPack1"] 45 | turn_prob_pets_exp[temp_turn][key] = temp_exp 46 | else: 47 | ### Assumption, if expansion info not provided, use standard info 48 | temp_exp = temp_std 49 | turn_prob_pets_exp[temp_turn][key] = temp_exp 50 | 51 | turn_prob_foods_std = {} 52 | turn_prob_foods_exp = {} 53 | for i in np.arange(0, 12): 54 | turn_prob_foods_std[i] = {} 55 | turn_prob_foods_exp[i] = {} 56 | 
for key, value in data["foods"].items(): 57 | if "probabilities" not in value: 58 | continue 59 | if data["foods"][key]["probabilities"] == "none": 60 | continue 61 | for temp_dict in data["foods"][key]["probabilities"]: 62 | if temp_dict == "none": 63 | continue 64 | temp_turn = int(temp_dict["turn"].split("-")[-1]) 65 | if "StandardPack" in temp_dict["perSlot"]: 66 | temp_std = temp_dict["perSlot"]["StandardPack"] 67 | turn_prob_foods_std[temp_turn][key] = temp_std 68 | if "ExpansionPack1" in temp_dict["perSlot"]: 69 | temp_exp = temp_dict["perSlot"]["ExpansionPack1"] 70 | turn_prob_foods_exp[temp_turn][key] = temp_exp 71 | else: 72 | ### Assumption, if expansion info not provided, use standard info 73 | temp_exp = temp_std 74 | turn_prob_foods_exp[temp_turn][key] = temp_exp 75 | -------------------------------------------------------------------------------- /sapai/compress.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import json, zlib 4 | import sapai 5 | 6 | 7 | def compress(obj, minimal=False): 8 | """ 9 | Will compress objects such that they can be stored and searched by a single 10 | string. This makes storage and quering of teams naively simple. 
11 | 12 | """ 13 | state = getattr(obj, "state", False) 14 | if not state: 15 | raise Exception(f"No state method found for obj {obj}") 16 | if minimal: 17 | state = minimal_state(obj) 18 | json_str = json.dumps(state) 19 | compressed_str = zlib.compress(json_str.encode()) 20 | return compressed_str 21 | 22 | 23 | def decompress(compressed_str): 24 | """ 25 | Decompress the given encoded str into an object 26 | 27 | """ 28 | state_str = zlib.decompress(compressed_str).decode() 29 | state_dict = json.loads(state_str) 30 | return state2obj(state_dict) 31 | 32 | 33 | def state2obj(state): 34 | obj_type = state["type"] 35 | obj_cls = getattr(sapai, obj_type) 36 | obj = obj_cls.from_state(state) 37 | return obj 38 | 39 | 40 | def sapai_hash(obj): 41 | """ 42 | Fast method for hashing the object 43 | 44 | """ 45 | state = getattr(obj, "state", False) 46 | if not state: 47 | raise Exception(f"No state found for obj {obj}") 48 | raise Exception("I can't find faster way to do this... But it would be very nice.") 49 | 50 | 51 | def minimal_state(obj): 52 | """ 53 | Including the seed_state from food/pets and including the action history from 54 | player creates a 10 times increase in the compressed byte size. In many 55 | situations it is advantageous to create a minimal state for only the team 56 | stats and current shop pets. This will save memory/storage and improve 57 | computational efficiency. 
58 | 59 | """ 60 | state = obj.state 61 | 62 | def minimal_pet_state(state): 63 | if "seed_state" in state: 64 | del state["seed_state"] 65 | 66 | def minimal_team_state(state): 67 | for teamslot_state in state["team"]: 68 | minimal_pet_state(teamslot_state["pet"]) 69 | 70 | def minimal_shop_state(state): 71 | if "seed_state" in state: 72 | del state["seed_state"] 73 | for shopslot_state in state["slots"]: 74 | if "seed_state" in shopslot_state: 75 | del shopslot_state["seed_state"] 76 | minimal_pet_state(shopslot_state["obj"]) 77 | 78 | def minimal_player_state(state): 79 | if "seed_state" in state: 80 | del state["seed_state"] 81 | if "action_history" in state: 82 | del state["action_history"] 83 | minimal_team_state(state["team"]) 84 | minimal_shop_state(state["shop"]) 85 | 86 | if state["type"] == "Pet": 87 | minimal_pet_state(state) 88 | elif state["type"] == "Food": 89 | minimal_pet_state(state) 90 | elif state["type"] == "Team": 91 | minimal_team_state(state) 92 | elif state["type"] == "Shop": 93 | minimal_shop_state(state) 94 | elif state["type"] == "Player": 95 | minimal_player_state(state) 96 | else: 97 | raise Exception(f"Unrecognized state type {state['type']}") 98 | 99 | return state 100 | 101 | 102 | # %% 103 | -------------------------------------------------------------------------------- /tests/test_seeds.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from sapai import * 5 | from sapai.compress import compress, decompress 6 | 7 | 8 | class TestSeeds(unittest.TestCase): 9 | def test_battle_reproducibility(self): 10 | state = np.random.RandomState(seed=4).get_state() 11 | state2 = np.random.RandomState(seed=4).get_state() 12 | 13 | winner_set = set() 14 | for i in range(20): 15 | t0 = Team(["ant", "ant", "fish"], seed_state=state) 16 | t1 = Team(["ant", "ant", "fish"], seed_state=state2) 17 | b = Battle(t0, t1) 18 | winner = b.battle() 19 | # Same state should 
result in draw 20 | winner_set.add(winner) 21 | self.assertEqual(len(winner_set), 1) 22 | 23 | def test_battle_reproducibility_after_compress(self): 24 | state = np.random.RandomState(seed=4).get_state() 25 | state2 = np.random.RandomState(seed=4).get_state() 26 | 27 | winner_set = set() 28 | for i in range(20): 29 | t0 = Team(["ant", "ant", "fish"], seed_state=state) 30 | t1 = Team(["ant", "ant", "fish"], seed_state=state2) 31 | 32 | t0 = decompress(compress(t0)) 33 | t1 = decompress(compress(t1)) 34 | 35 | b = Battle(t0, t1) 36 | winner = b.battle() 37 | # Same state should result in draw 38 | winner_set.add(winner) 39 | self.assertEqual(len(winner_set), 1) 40 | 41 | def test_shop_reproducibility(self): 42 | state = np.random.RandomState(seed=20).get_state() 43 | s = Shop(turn=11, seed_state=state) 44 | # ref_init_state = s.state 45 | 46 | # Setup solution 47 | shop_check_list = [] 48 | s = Shop(turn=11, seed_state=state) 49 | for i in range(10): 50 | names = [] 51 | for slot in s: 52 | names.append(slot.obj.name) 53 | names = tuple(names) 54 | shop_check_list.append(names) 55 | s.roll() 56 | shop_check_list = tuple(shop_check_list) 57 | 58 | # Run check for reproducibility 59 | for i in range(10): 60 | s = Shop(turn=11, seed_state=state) 61 | temp_check_list = [] 62 | for i in range(10): 63 | names = [] 64 | for slot in s: 65 | names.append(slot.obj.name) 66 | names = tuple(names) 67 | temp_check_list.append(names) 68 | s.roll() 69 | self.assertEqual(tuple(temp_check_list), shop_check_list) 70 | 71 | for i in range(10): 72 | s = Shop(turn=11, seed_state=state) 73 | s = decompress(compress(s)) 74 | # self.assertEqual(s.state, ref_init_state) 75 | temp_check_list = [] 76 | for i in range(10): 77 | names = [] 78 | for slot in s: 79 | names.append(slot.obj.name) 80 | names = tuple(names) 81 | temp_check_list.append(names) 82 | s.roll() 83 | self.assertEqual(tuple(temp_check_list), shop_check_list) 84 | 85 | 86 | # from sapai.compress import compress, decompress 87 | 
# state = np.random.RandomState(seed=20).get_state() 88 | # s = Shop(turn=11, seed_state=state) 89 | # test_s = decompress(compress(s)) 90 | -------------------------------------------------------------------------------- /tests/test_player.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | 4 | from sapai import * 5 | 6 | 7 | class TestPlayer(unittest.TestCase): 8 | def setUp(self) -> None: 9 | self.pack = "StandardPack" 10 | 11 | def test_buy_three_animals(self): 12 | player = Player(pack=self.pack) 13 | player.buy_pet(player.shop[0]) 14 | player.buy_pet(player.shop[0]) 15 | player.buy_pet(player.shop[0]) 16 | 17 | def test_buy_two_animals_one_food(self): 18 | player = Player(pack=self.pack) 19 | player.buy_pet(player.shop[0]) 20 | player.buy_pet(player.shop[0]) 21 | player.buy_food(player.shop[-1], player.team[0]) 22 | player.sell(player.team[0]) 23 | 24 | def test_freeze(self): 25 | player = Player(pack=self.pack) 26 | player.freeze(0) 27 | player.shop.roll() 28 | 29 | def test_unfreeze(self): 30 | player = Player(pack=self.pack) 31 | player.unfreeze(0) 32 | player.shop.roll() 33 | 34 | def test_buy_combine_behavior(self): 35 | player = Player( 36 | shop=["ant", "fish", "fish", "apple"], team=["fish", "ant"], pack=self.pack 37 | ) 38 | player.buy_combine(player.shop[1], player.team[0]) 39 | player.buy_combine(player.shop[1], player.team[0]) 40 | 41 | def test_buy_combine_behavior2(self): 42 | player = Player( 43 | shop=["ant", "octopus", "octopus", "apple"], 44 | team=["octopus", "ant"], 45 | pack=self.pack, 46 | ) 47 | player.buy_combine(player.shop[1], player.team[0]) 48 | player.buy_combine(player.shop[1], player.team[0]) 49 | 50 | def test_combine_behavior(self): 51 | player = Player( 52 | shop=["ant", "fish", "fish", "apple"], 53 | team=["fish", "fish", "fish", "horse"], 54 | pack=self.pack, 55 | ) 56 | player.combine(player.team[0], player.team[1]) 57 | player.combine(player.team[0], 
player.team[2]) 58 | 59 | def test_cat_behavior(self): 60 | player = Player( 61 | shop=["ant", "fish", "fish", "pear"], team=["fish", "cat"], pack=self.pack 62 | ) 63 | player.buy_food(player.shop[-1], player.team[0]) 64 | 65 | def test_start_of_turn_behavior(self): 66 | player = Player( 67 | shop=["ant", "fish", "fish", "pear"], 68 | team=["dromedary", "swan", "caterpillar", "squirrel"], 69 | pack=self.pack, 70 | ) 71 | player.team[0]._pet.level = 2 72 | player.start_turn() 73 | 74 | def test_sell_buy_behavior(self): 75 | player = Player( 76 | shop=["otter", "fish", "fish", "pear"], 77 | team=["pig", "fish", "ant", "beaver", "pig"], 78 | pack=self.pack, 79 | ) 80 | player.sell_buy(0, 0) 81 | 82 | def test_pill_behavior(self): 83 | player = Player( 84 | shop=["ant", "fish", "fish", "food-sleeping-pill"], 85 | team=["rooster", "ant", "cricket", "sheep"], 86 | pack=self.pack, 87 | ) 88 | player.buy_food(player.shop[-1], player.team[1]) 89 | 90 | def test_multi_faints(self): 91 | player = Player( 92 | shop=["ant", "fish", "fish", "food-sleeping-pill"], 93 | team=["hedgehog", "ant", "ox", "sheep", "dragon"], 94 | pack=self.pack, 95 | ) 96 | player.buy_food(player.shop[-1], player.team[0]) 97 | 98 | def test_deer_microbe_shark(self): 99 | player = Player( 100 | shop=["ant", "fish", "fish", "food-sleeping-pill"], 101 | team=["deer", "microbe", "eagle", "shark"], 102 | pack=self.pack, 103 | ) 104 | player.buy_food(player.shop[-1], player.team[2]) 105 | 106 | def test_deer_badger_fly_sheep(self): 107 | player = Player( 108 | shop=["ant", "fish", "fish", "food-sleeping-pill"], 109 | team=["deer", "badger", "sheep", "fly"], 110 | pack=self.pack, 111 | ) 112 | player.buy_food(player.shop[-1], player.team[1]) 113 | 114 | 115 | # %% 116 | -------------------------------------------------------------------------------- /sapai/foods.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy as np 3 | 4 | from sapai.data import 
data 5 | from sapai.rand import MockRandomState 6 | 7 | # %% 8 | 9 | 10 | class Food: 11 | def __init__(self, name="food-none", shop=None, seed_state=None): 12 | """ 13 | Food class definition the types of interactions that food undergoes 14 | 15 | """ 16 | if len(name) != 0: 17 | if not name.startswith("food-"): 18 | name = f"food-{name}" 19 | 20 | self.eaten = False 21 | self.shop = shop 22 | 23 | self.seed_state = seed_state 24 | if self.seed_state is not None: 25 | self.rs = np.random.RandomState() 26 | self.rs.set_state(self.seed_state) 27 | else: 28 | ### Otherwise, set use 29 | self.rs = MockRandomState() 30 | 31 | self.attack = 0 32 | self.health = 0 33 | self.base_attack = 0 34 | self.base_health = 0 35 | self.apply_until_end_of_battle = False 36 | self.status = "none" 37 | self.effect = "none" 38 | self.fd = {} 39 | 40 | self.name = name 41 | if name not in data["foods"]: 42 | raise Exception(f"Food {name} not found") 43 | 44 | self.cost = 3 45 | item = data["foods"][name] 46 | if "cost" in item: 47 | self.cost = item["cost"] 48 | 49 | fd = item["ability"] 50 | self.fd = fd 51 | 52 | self.attack = 0 53 | self.health = 0 54 | self.effect = fd["effect"] 55 | if "attackAmount" in fd["effect"]: 56 | self.attack = fd["effect"]["attackAmount"] 57 | self.base_attack = fd["effect"]["attackAmount"] 58 | if "healthAmount" in fd["effect"]: 59 | self.health = fd["effect"]["healthAmount"] 60 | self.base_health = fd["effect"]["healthAmount"] 61 | if "status" in fd["effect"]: 62 | self.status = fd["effect"]["status"] 63 | if ( 64 | "untilEndOfBattle" in fd["effect"] 65 | and fd["effect"]["untilEndOfBattle"] is True 66 | ): 67 | self.apply_until_end_of_battle = True 68 | 69 | def copy(self): 70 | copy_food = Food(self.name, self.shop) 71 | for key, value in self.__dict__.items(): 72 | ### Although this approach will copy the internal dictionaries by 73 | ### reference rather than copy by value, these dictionaries will 74 | ### never be modified anyways. 
75 | ### All integers and strings are copied by value automatically with 76 | ### Python, therefore, this achieves the correct behavior 77 | copy_food.__dict__[key] = value 78 | return copy_food 79 | 80 | @property 81 | def state(self): 82 | #### Ensure that state can be JSON serialized 83 | if getattr(self, "rs", False): 84 | if isinstance(self.rs, MockRandomState): 85 | seed_state = None 86 | else: 87 | seed_state = list(self.rs.get_state()) 88 | seed_state[1] = seed_state[1].tolist() 89 | else: 90 | seed_state = None 91 | state_dict = { 92 | "type": "Food", 93 | "name": self.name, 94 | "eaten": self.eaten, 95 | "attack": self.attack, 96 | "health": self.health, 97 | "apply_until_end_of_battle": self.apply_until_end_of_battle, 98 | "seed_state": seed_state, 99 | } 100 | return state_dict 101 | 102 | @classmethod 103 | def from_state(cls, state): 104 | food = cls(name=state["name"]) 105 | food.attack = state["attack"] 106 | food.health = state["health"] 107 | food.eaten = state["eaten"] 108 | food.apply_until_end_of_battle = state["apply_until_end_of_battle"] 109 | ### Supply seed_state in state dict should be optional 110 | if "seed_state" in state: 111 | if state["seed_state"] is not None: 112 | food.seed_state = state["seed_state"] 113 | food.rs = np.random.RandomState() 114 | food.rs.set_state(state["seed_state"]) 115 | return food 116 | 117 | def __repr__(self): 118 | return f"< {self.name} {self.attack}-{self.health} {self.status} >" 119 | 120 | 121 | # %% 122 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | sapai 3 | ===== 4 | 5 | 6 | | |license| |test-status| |coverage| 7 | 8 | 9 | ``sapai`` is a Super Auto Pets engine built with reinforcement learning training and other related AI models in mind. You may see more of my published work in machine learning on `ResearchGate `_ or `ACS `_. 
10 | 11 | You may see and use ``sapai`` examples easily through `Google Colab `_ 12 | 13 | .. figure:: doc/static/workflow.png 14 | :height: 380 15 | :width: 404 16 | :align: center 17 | 18 | 19 | .. contents:: 20 | :local: 21 | 22 | ------------ 23 | Installation 24 | ------------ 25 | 26 | To start installing and using ``sapai``, it's highly recommended to start from an Anaconda distribution of Python, which can be downloaded for free here_. 27 | 28 | .. _here: https://www.anaconda.com/products/individual 29 | 30 | Then download the library from GitHub. A ``zip`` file can be downloaded using the green ``Code`` button. Alternatively, this repository can be obtained using the following command from the command-line. 31 | 32 | .. code-block:: bash 33 | 34 | git clone https://github.com/manny405/sapai.git 35 | 36 | After navigating to the ``sapai`` directory, installation is completed with the following command. 37 | 38 | .. code-block:: bash 39 | 40 | python setup.py install 41 | 42 | Unit tests are located in the ``tests`` directory. Tests can be run with the following command. 43 | 44 | .. code-block:: bash 45 | 46 | python -m unittest discover -s tests 47 | 48 | 49 | --------------------------- 50 | Introduction: Code Examples 51 | --------------------------- 52 | 53 | The following code examples will be run through the Python shell. To start a Python shell session, open up your preferred command-line program, such as Terminal or PowerShell, then type and enter ``python``. 54 | 55 | ############### 56 | Creating a Pet 57 | ############### 58 | 59 | ..
code-block:: python 60 | 61 | >>> from sapai.pets import Pet 62 | >>> pet = Pet("ant") 63 | >>> print(pet) 64 | ### Printing pet is given in the form of < PetName Attack-Health Status Level-Exp > 65 | < pet-ant 2-1 none 1-0 > 66 | >>> pet._attack += 3 67 | >>> pet.gain_experience() 68 | >>> print(pet) 69 | < pet-ant 5-1 none 1-1 > 70 | >>> print(pet.ability) 71 | ### Organization of pet abilities provided by super-auto-pets-db project 72 | {'description': 'Faint: Give a random friend +2/+1', 73 | 'trigger': 'Faint', 74 | 'triggeredBy': {'kind': 'Self'}, 75 | 'effect': {'kind': 'ModifyStats', 76 | 'attackAmount': 2, 77 | 'healthAmount': 1, 78 | 'target': {'kind': 'RandomFriend', 'n': 1}, 79 | 'untilEndOfBattle': False}} 80 | 81 | 82 | ############### 83 | Creating a Team 84 | ############### 85 | 86 | .. code-block:: python 87 | 88 | >>> from sapai.pets import Pet 89 | >>> from sapai.teams import Team 90 | >>> team0 = Team(["ant","ox","tiger"]) 91 | >>> team1 = Team(["sheep","tiger"]) 92 | >>> print(team0) 93 | 0: < Slot pet-ant 2-1 none 1-0 > 94 | 1: < Slot pet-ox 1-4 none 1-0 > 95 | 2: < Slot pet-tiger 4-3 none 1-0 > 96 | 3: < Slot EMPTY > 97 | 4: < Slot EMPTY > 98 | >>> print(team1) 99 | 0: < Slot pet-sheep 2-2 none 1-0 > 100 | 1: < Slot pet-tiger 4-3 none 1-0 > 101 | 2: < Slot EMPTY > 102 | 3: < Slot EMPTY > 103 | 4: < Slot EMPTY > 104 | >>> team0.move(1,4) 105 | >>> print(team0) 106 | 0: < Slot pet-ant 2-1 none 1-0 > 107 | 1: < Slot EMPTY > 108 | 2: < Slot pet-tiger 4-3 none 1-0 > 109 | 3: < Slot EMPTY > 110 | 4: < Slot pet-ox 1-4 none 1-0 > 111 | >>> team0.move_forward() 112 | >>> print(team0) 113 | 0: < Slot pet-ant 2-1 none 1-0 > 114 | 1: < Slot pet-tiger 4-3 none 1-0 > 115 | 2: < Slot pet-ox 1-4 none 1-0 > 116 | 3: < Slot EMPTY > 117 | 4: < Slot EMPTY > 118 | 119 | ####### 120 | Battles 121 | ####### 122 | 123 | .. 
code-block:: python 124 | 125 | ### Using the teams created in the last section 126 | >>> from sapai.battle import Battle 127 | >>> battle = Battle(team0,team1) 128 | >>> winner = battle.battle() 129 | >>> print(winner) 130 | 2 131 | ### Possible fight outputs: 132 | ### 0 = Team0 Wins 133 | ### 1 = Team1 Wins 134 | ### 2 = Draw 135 | 136 | The battle implementation is efficient. This can be measured with the IPython ``%timeit`` magic as follows: 137 | 138 | .. code-block:: python 139 | 140 | from sapai.pets import Pet 141 | from sapai.teams import Team 142 | from sapai.battle import Battle 143 | team0 = Team(["ant","ox","tiger"]) 144 | team1 = Team(["sheep","tiger"]) 145 | 146 | def timing_test(): 147 | b = Battle(team0,team1) 148 | winner = b.battle() 149 | 150 | %timeit timing_test() 151 | ### On 2019 MacBook Pro: 152 | ### 8.12 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 153 | ### More than 100 battles per second on a single core 154 | 155 | ### On Xeon Platinum 8124M @ 3.00GHz 156 | ### 6.06 ms ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 157 | ### More than 150 battles per second on a single core 158 | 159 | ### On 2021 MacBook Pro with M1 Pro processor: 160 | ### 4.32 ms ± 20.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 161 | ### More than 230 battles per second on a single core 162 | 163 | ^^^^^^^^^^^^ 164 | Battle Graph 165 | ^^^^^^^^^^^^ 166 | 167 | All battle history is stored for every phase, effect, and attack that occurred during the battle. This battle history can be graphed and visualized. The full graph for the battle is shown below. 168 | 169 | >>> from sapai.graph import graph_battle 170 | >>> graph_battle(battle, file_name="Example") 171 | 172 | 173 | .. figure:: doc/static/battle_graph_full.png 174 | :height: 2140 175 | :width: 536 176 | :align: center 177 | 178 | 179 | ------ 180 | Status 181 | ------ 182 | 183 | Ongoing 184 | 185 | 1. See the issues page for ongoing discussions.
The code-base is completely ready for the development of AI engines around SAP. 186 | 187 | 188 | .. |license| image:: https://img.shields.io/badge/License-MIT-yellow.svg 189 | .. |test-status| image:: https://github.com/manny405/sapai/actions/workflows/run_tests.yml/badge.svg 190 | .. |coverage| image:: https://codecov.io/gh/manny405/sapai/branch/main/graph/badge.svg?token=5RDE13SYET 191 | -------------------------------------------------------------------------------- /tests/test_effects.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sapai import * 4 | from sapai.battle import * 5 | from sapai.effects import * 6 | from sapai.compress import * 7 | 8 | 9 | class TestEffects(unittest.TestCase): 10 | def test_target_function(self): 11 | shop = Shop() 12 | shop.roll() 13 | t0 = Team(["fish", "dragon", "owl", "ant", "badger"], shop=shop) 14 | t1 = Team(["cat", "hippo", "horse", "ant"], shop=shop) 15 | for slot in t0: 16 | slot.pet.shop = shop 17 | for slot in t1: 18 | slot.pet.shop = shop 19 | t1[0].pet._health = 7 20 | t0[0].pet.level = 2 21 | t0[1].pet.level = 3 22 | t0[0].pet._health = 8 23 | 24 | def test_all_functions(self): 25 | all_func = [x for x in func_dict.keys()] 26 | pet_func = {} 27 | for pet, fd in data["pets"].items(): 28 | if "level1Ability" not in fd: 29 | continue 30 | kind = fd["level1Ability"]["effect"]["kind"] 31 | if kind not in pet_func: 32 | pet_func[kind] = [] 33 | pet_func[kind].append(pet) 34 | 35 | shop = Shop() 36 | shop.roll() 37 | base_team = Team(["fish", "dragon"]) 38 | for slot in base_team: 39 | slot.pet.shop = shop 40 | slot.pet.team = base_team 41 | tc = compress(base_team) 42 | for func_name in all_func: 43 | print(func_name) 44 | if func_name not in pet_func: 45 | continue 46 | for pet in pet_func[func_name]: 47 | temp_team = decompress(tc) 48 | temp_team.append(pet) 49 | temp_team[2].pet.shop = shop 50 | temp_team.append("bison") 51 | temp_team[3].pet.shop = 
shop 52 | temp_enemy_team = decompress(tc) 53 | func = get_effect_function(temp_team[2]) 54 | apet = temp_team[2].pet 55 | apet_idx = [0, 2] 56 | teams = [temp_team, temp_enemy_team] 57 | te = temp_team[2].pet 58 | te_idx = [0, 2] 59 | fixed_targets = [] 60 | 61 | if func_name == "RepeatAbility": 62 | te = temp_team[1].pet 63 | if func_name == "FoodMultiplier": 64 | te = Food("pear") 65 | targets, possible = func( 66 | apet, apet_idx, teams, te, te_idx, fixed_targets 67 | ) 68 | 69 | def test_tiger_func(self): 70 | t = Team(["spider", "tiger"], battle=True) 71 | slot_list = [x for x in t] 72 | for slot in slot_list: 73 | if slot.empty: 74 | continue 75 | slot.pet.faint_trigger(slot.pet, [0, t.index(slot)]) 76 | 77 | def test_eagle_stats(self): 78 | # seed for Snake 6/6 79 | state = np.random.RandomState(seed=4).get_state() 80 | 81 | pet = Pet("eagle", seed_state=state) 82 | pet.level = 3 83 | t = Team([pet], battle=True) 84 | t[0].pet.faint_trigger(t[0].pet, [0, t.index(t[0])]) 85 | 86 | # should spawn Snake Lvl3 18/18 since Eagle was lvl 3 87 | self.assertEqual(t[0].level, 3) 88 | self.assertEqual(t[0].attack, 18) 89 | self.assertEqual(t[0].health, 18) 90 | 91 | def test_multiple_cats(self): 92 | player = Player(shop=Shop(["pear"]), team=Team([Pet("cat")])) 93 | player.buy_food(0, 0) 94 | 95 | # should add +4/+4 96 | self.assertEqual(player.team[0].attack, 8) 97 | self.assertEqual(player.team[0].health, 9) 98 | 99 | player = Player(shop=Shop(["pear"]), team=Team([Pet("cat"), Pet("cat")])) 100 | player.buy_food(0, 0) 101 | 102 | # should add +6/+6 103 | self.assertEqual(player.team[0].attack, 10) 104 | self.assertEqual(player.team[0].health, 11) 105 | 106 | player = Player( 107 | shop=Shop(["pear"]), team=Team([Pet("cat"), Pet("cat"), Pet("cat")]) 108 | ) 109 | player.buy_food(0, 0) 110 | 111 | # should add +8/+8 112 | self.assertEqual(player.team[0].attack, 12) 113 | self.assertEqual(player.team[0].health, 13) 114 | 115 | def test_melon(self): 116 | leopard = 
Pet("leopard") 117 | leopard.status = "status-melon-armor" 118 | leopard._health = 50 119 | leopard._attack = 50 120 | fish = Pet("fish") 121 | fish.status = "status-melon-armor" 122 | fish._health = 50 123 | fish._attack = 50 124 | 125 | t0 = Team([leopard], battle=True) 126 | t1 = Team([fish], battle=True) 127 | 128 | leopard.sob_trigger(t1) 129 | self.assertEqual(fish.health, 45) # 25 damage, -20 melon 130 | self.assertEqual(fish.status, "none") 131 | 132 | attack_phase = get_attack(leopard, fish) 133 | self.assertEqual(attack_phase[1], 30) # fish hits melon 134 | self.assertEqual(leopard.status, "none") 135 | 136 | def test_garlic(self): 137 | fish = Pet("fish") 138 | fish.status = "status-garlic-armor" 139 | fish._health = 50 140 | fish._attack = 50 141 | 142 | t = Team(["dolphin", "otter", "mosquito"], battle=True) 143 | t2 = Team([fish], battle=True) 144 | t[0].pet.sob_trigger(t2) 145 | self.assertEqual(fish.health, 47) # 5 damage, -2 garlic 146 | 147 | t[2].pet.sob_trigger(t2) 148 | self.assertEqual(fish.health, 46) # should still do 1 damage 149 | 150 | attack_phase = get_attack(t[0].pet, fish) 151 | self.assertEqual(attack_phase[0], 2) # dolphin 4/6 152 | 153 | attack_phase = get_attack(t[1].pet, fish) 154 | self.assertEqual(attack_phase[0], 1) # otter 1/2 155 | 156 | def test_coconut(self): 157 | gorilla = Pet("gorilla") 158 | gorilla.status = "status-coconut-shield" 159 | t = Team([gorilla], battle=True) 160 | t2 = Team(["crocodile"], battle=True) 161 | t3 = Team(["dragon"], battle=True) 162 | 163 | t2[0].pet.sob_trigger(t) 164 | self.assertEqual(gorilla.health, 9) # unchanged 165 | self.assertEqual(gorilla.status, "none") 166 | 167 | gorilla.status = "status-coconut-shield" 168 | attack_phase = get_attack(gorilla, t3[0].pet) 169 | self.assertEqual(attack_phase[1], 0) # dragon hits coconut 170 | self.assertEqual(gorilla.status, "none") 171 | 172 | def test_weak(self): 173 | fish = Pet("fish") 174 | fish.status = "status-weak" 175 | fish._health = 50 176 
| fish._attack = 50 177 | t = Team([fish], battle=True) 178 | t2 = Team(["dolphin"], battle=True) 179 | t3 = Team(["dragon"], battle=True) 180 | 181 | t2[0].pet.sob_trigger(t) 182 | self.assertEqual(fish.health, 42) # 5 + 3 183 | 184 | attack_phase = get_attack(fish, t3[0].pet) 185 | self.assertEqual(attack_phase[1], 9) # 6/8 + 3 186 | 187 | def test_hatching_chick_level_3(self): 188 | hc = Pet("hatching-chick") 189 | hc.level = 3 190 | t = Team(["dragon", hc]) 191 | hc.sot_trigger(t) 192 | self.assertEqual(t[0].pet.experience, 1) 193 | -------------------------------------------------------------------------------- /tests/test_shop.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | 4 | from sapai import Player, Team 5 | from sapai.shop import * 6 | 7 | 8 | class TestShop(unittest.TestCase): 9 | def test_shop_slot_pet(self): 10 | slot = ShopSlot("pet") 11 | slot.obj = Pet("ant") 12 | slot.roll() 13 | self.assertIsInstance(slot.obj, Pet) 14 | 15 | def test_shop_slot_food(self): 16 | slot = ShopSlot("food") 17 | slot.obj = Food("apple") 18 | slot.roll() 19 | self.assertIsInstance(slot.obj, Food) 20 | 21 | def test_shop_level_up(self): 22 | slot = ShopSlot("levelup") 23 | tier = slot.obj.tier 24 | self.assertEqual(tier, 2) 25 | 26 | def test_max_shop(self): 27 | s = Shop(turn=11) 28 | s.freeze(0) 29 | ref_state = s[0].state 30 | for index in range(10): 31 | s.roll() 32 | self.assertEqual(ref_state, s[0].state) 33 | 34 | def test_rabbit_buy_food(self): 35 | test_player = Player(shop=["honey"], team=["rabbit"]) 36 | start_health = 2 37 | self.assertEqual(test_player.team[0].pet.health, start_health) 38 | 39 | test_player.buy_food(0, 0) 40 | expected_end_health = start_health + 1 41 | self.assertEqual(test_player.team[0].pet.health, expected_end_health) 42 | 43 | def test_empty_shop_from_state(self): 44 | pet = Pet("fish") 45 | orig_shop = Shop(slots=[pet]) 46 | orig_shop.buy(pet) 47 | 
self.assertEqual(len(orig_shop.slots), 0) 48 | 49 | copy_shop = Shop.from_state(orig_shop.state) 50 | self.assertEqual(len(copy_shop.slots), 0) 51 | 52 | def test_combine_scorpions(self): 53 | player = Player(team=["scorpion", "scorpion"]) 54 | player.combine(0, 1) 55 | 56 | def test_combine_coconut_shield(self): 57 | gorilla = Pet("gorilla") 58 | gorilla.status = "status-coconut-shield" 59 | gorilla2 = Pet("gorilla") 60 | gorilla2.status = "status-melon-armor" 61 | player = Player(team=[gorilla, gorilla2]) 62 | player.combine(1, 0) 63 | self.assertEqual( 64 | gorilla.status, "status-coconut-shield" 65 | ) # same priority, therefore pet-to-keep keeps its status 66 | 67 | def test_squirrel(self): 68 | player = Player(team=Team([Pet("squirrel")])) 69 | player.start_turn() 70 | self.assertEqual(player.shop[3].cost, 2) 71 | 72 | player.roll() 73 | self.assertEqual(player.shop[3].cost, 3) 74 | 75 | def test_pill_1gold(self): 76 | player = Player(shop=Shop(["sleeping-pill"]), team=Team(["fish"])) 77 | player.buy_food(0, 0) 78 | self.assertEqual(player.gold, 9) 79 | 80 | def test_cupcake(self): 81 | player = Player(shop=Shop(["cupcake"]), team=Team([Pet("fish")])) 82 | 83 | player.buy_food(0, 0) 84 | self.assertEqual(player.team[0].attack, 5) # fish 2/2 85 | self.assertEqual(player.team[0].health, 5) 86 | 87 | player.end_turn() 88 | player.start_turn() 89 | 90 | self.assertEqual(player.team[0].attack, 2) 91 | self.assertEqual(player.team[0].health, 2) 92 | 93 | def test_apple(self): 94 | player = Player(shop=Shop(["apple"]), team=Team([Pet("beaver")])) 95 | 96 | player.buy_food(0, 0) 97 | self.assertEqual(player.team[0].attack, 4) 98 | self.assertEqual(player.team[0].health, 3) 99 | 100 | def test_shop_levelup_from_combine(self): 101 | player = Player(shop=Shop(["fish", "fish"]), team=Team([Pet("fish")])) 102 | player.buy_combine(1, 0) 103 | player.buy_combine(0, 0) 104 | self.assertEqual(len(player.shop), 1) 105 | 106 | def test_shop_levelup_from_ability(self): 107 | pet 
= Pet("caterpillar") 108 | pet.level = 2 109 | pet.experience = 2 110 | player = Player(shop=Shop([]), team=Team([pet])) 111 | pet.sot_trigger() 112 | self.assertEqual(len(player.shop.filled), 5) 113 | 114 | def test_buy_multi_target_food(self): 115 | player = Player(shop=["sushi"], team=["seal", "rabbit", "ladybug"]) 116 | player.buy_food(0) 117 | self.assertEqual(player.team[0].attack, 4) # 3 + sushi 118 | self.assertEqual(player.team[0].health, 10) # 8 + sushi + rabbit 119 | self.assertEqual(player.team[1].attack, 3) # 1 + sushi + seal 120 | self.assertEqual(player.team[1].health, 5) # 2 + sushi + seal + rabbit 121 | self.assertEqual(player.team[2].attack, 4) # 1 + sushi + seal + ladybug 122 | self.assertEqual( 123 | player.team[2].health, 7 124 | ) # 3 + sushi + seal + rabbit + ladybug 125 | 126 | def test_buy_multi_target_food_empty_team(self): 127 | player = Player(shop=["sushi"], team=[]) 128 | player.buy_food(0) 129 | 130 | def test_buy_chocolate(self): 131 | player = Player(shop=["chocolate"], team=["seal", "rabbit", "ladybug"]) 132 | player.buy_food(0, 0) 133 | self.assertEqual(player.team[0].pet.experience, 1) 134 | self.assertEqual(player.team[0].attack, 3) # 3 135 | self.assertEqual(player.team[0].health, 9) # 8 + rabbit 136 | self.assertEqual(player.team[1].attack, 2) # 1 + seal 137 | self.assertEqual(player.team[1].health, 3) # 2 + seal 138 | self.assertEqual(player.team[2].attack, 3) # 1 + seal + ladybug 139 | self.assertEqual(player.team[2].health, 5) # 3 + seal + ladybug 140 | 141 | def test_buy_apple(self): 142 | player = Player(shop=["apple"], team=["seal", "rabbit", "ladybug"]) 143 | player.buy_food(0, 0) 144 | self.assertEqual(player.team[0].attack, 4) # 3 + apple 145 | self.assertEqual(player.team[0].health, 10) # 8 + apple + rabbit 146 | self.assertEqual(player.team[1].attack, 2) # 1 + seal 147 | self.assertEqual(player.team[1].health, 3) # 2 + seal 148 | self.assertEqual(player.team[2].attack, 3) # 1 + seal + ladybug 149 | 
self.assertEqual(player.team[2].health, 5) # 3 + seal + ladybug 150 | 151 | def test_chicken(self): 152 | state = np.random.RandomState(seed=1).get_state() 153 | player = Player(shop=Shop(["fish", "fish"], seed_state=state), team=["chicken"]) 154 | player.buy_pet(0) 155 | self.assertEqual(player.shop[0].obj.attack, 3) # fish 2/2 156 | self.assertEqual(player.shop[0].obj.health, 3) 157 | 158 | ### check result after 1 roll 159 | player.roll() 160 | self.assertEqual(player.shop[0].obj.attack, 3) # duck 2/3 161 | self.assertEqual(player.shop[0].obj.health, 4) 162 | 163 | ### check result in a new turn 164 | player.end_turn() 165 | player.start_turn() 166 | 167 | self.assertEqual(player.shop[0].obj.attack, 3) # mosquito 2/2 168 | self.assertEqual(player.shop[0].obj.health, 3) 169 | 170 | def test_canned_food(self): 171 | state = np.random.RandomState(seed=1).get_state() 172 | player = Player( 173 | shop=Shop(["fish", "canned-food"], seed_state=state), team=["fish"] 174 | ) 175 | player.buy_food(1) 176 | 177 | ### check immediate result 178 | self.assertEqual(player.shop[0].obj.attack, 4) # fish 2/2 179 | self.assertEqual(player.shop[0].obj.health, 3) 180 | 181 | ### check result after 1 roll 182 | player.roll() 183 | self.assertEqual(player.shop[0].obj.attack, 4) # duck 2/3 184 | self.assertEqual(player.shop[0].obj.health, 4) 185 | 186 | ### check result in a new turn 187 | player.end_turn() 188 | player.start_turn() 189 | self.assertEqual(player.shop[0].obj.attack, 4) # mosquito 2/2 190 | self.assertEqual(player.shop[0].obj.health, 3) 191 | 192 | 193 | # %% 194 | -------------------------------------------------------------------------------- /sapai/graph.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | from graphviz import Digraph 4 | from sapai import Pet 5 | 6 | 7 | def html_table( 8 | header="", 9 | entries=None, 10 | table_attr=None, 11 | header_font_attr=None, 12 | header_bg_color="#1C6EA4", 13 | 
cell_font_attr=None, 14 | cell_border=0, 15 | column_align=None, 16 | cell_bg_colors=None, 17 | ): 18 | entries = entries or [[]] 19 | table_attr = table_attr or [("BORDER", "1")] 20 | header_font_attr = header_font_attr or [("COLOR", "#000000")] 21 | cell_font_attr = cell_font_attr or [] 22 | column_align = column_align or ["RIGHT", "LEFT"] 23 | cell_bg_colors = cell_bg_colors or [[]] 24 | 25 | table_attr_str = "" 26 | for attr, value in table_attr: 27 | table_attr_str += f""" {attr}="{value}" """ 28 | table_str = f"<" 29 | 30 | num_rows = 0 31 | if len(header) > 0: 32 | num_rows += 1 33 | 34 | if len(entries) != 0: 35 | if type(entries[0]) != list: 36 | raise Exception("Entries argument should be list of lists.") 37 | else: 38 | num_rows += len(entries) 39 | 40 | if len(cell_bg_colors) != 0: 41 | if type(cell_bg_colors[0]) != list: 42 | raise Exception("Argument cell_bg_colors should be list of lists") 43 | for iter_idx, temp_entry in enumerate(cell_bg_colors): 44 | temp_length = len(temp_entry) 45 | temp_check_length = len(entries[iter_idx]) 46 | if temp_length != 0: 47 | raise NotImplementedError 48 | if temp_length != temp_check_length: 49 | raise Exception("Must supply one cell_bg_color for every entry") 50 | 51 | num_columns = 0 52 | for entry in entries: 53 | temp_columns = len(entry) 54 | num_columns = max(num_columns, temp_columns) 55 | 56 | ### Build string starting with header 57 | if len(header) > 0: 58 | temp_str = " " 59 | 60 | temp_font_attr_str = "" 61 | for attr, value in header_font_attr: 62 | temp_font_attr_str += f""" {attr}="{value}" """ 63 | 64 | temp_str += f"""""" 65 | temp_str += "" 66 | table_str += temp_str 67 | 68 | cell_font_attr_str = "" 69 | for attr, value in cell_font_attr: 70 | cell_font_attr_str += f""" {attr}="{value}" """ 71 | 72 | for entry in entries: 73 | temp_str = " " 74 | for column_idx in range(num_columns): 75 | if column_idx < len(entry): 76 | temp_cell_str = entry[column_idx] 77 | else: 78 | temp_cell_str = "" 79 | 
temp_align = column_align[column_idx % len(column_align)] 80 | temp_str += f"""""" 81 | temp_str += "" 82 | table_str += temp_str 83 | 84 | table_str += "
{header}
{temp_cell_str}
>" 85 | return table_str 86 | 87 | 88 | def prep_pet_str(pstr): 89 | if isinstance(pstr, Pet): 90 | pstr = str(pstr) 91 | temp_pstr = pstr.replace("<", "") 92 | temp_pstr = temp_pstr.replace(">", "") 93 | temp_pstr = temp_pstr.replace(" Slot ", "") 94 | temp_pstr = temp_pstr.replace("pet-", "") 95 | # temp_pstr = temp_pstr.replace("EMPTY "," ") 96 | temp_pstr = temp_pstr.replace("none", "") 97 | temp_pstr = " ".join(temp_pstr.split()) 98 | temp_pstr_list = temp_pstr.split() 99 | if len(temp_pstr_list) == 2: 100 | temp_stats = temp_pstr_list[1] 101 | replace_index = temp_stats.index("-") 102 | temp_pstr_list[ 103 | 1 104 | ] = f"{temp_stats[0:replace_index]},{temp_stats[replace_index + 1 :]}" 105 | temp_pstr = " ".join(temp_pstr_list) 106 | return temp_pstr 107 | 108 | 109 | def prep_pet_str_obj(obj): 110 | """Recursive function to prepare pet string""" 111 | if type(obj) == str: 112 | return prep_pet_str(obj) 113 | elif type(obj) == list: 114 | ret_obj = [] 115 | for temp_entry in obj: 116 | ret_obj += [prep_pet_str_obj(temp_entry)] 117 | return ret_obj 118 | else: 119 | raise Exception(type(obj)) 120 | 121 | 122 | def prep_effect(effect_list): 123 | effect_name = effect_list[0] 124 | team_idx = effect_list[1][0] 125 | pet_idx = effect_list[1][1] 126 | pet_str = prep_pet_str(effect_list[2]) 127 | target_str = prep_pet_str_obj(effect_list[3]) 128 | target_str = " ".join(target_str) 129 | if len(target_str) == 0: 130 | target_str = " " 131 | effect_columns = [ 132 | "Effect", 133 | "Team", 134 | " Pet Index ", 135 | " Activating Pet ", 136 | "Targets", 137 | ] 138 | return ( 139 | [effect_name, str(team_idx), str(pet_idx), pet_str, target_str], 140 | effect_columns, 141 | ) 142 | 143 | 144 | def graph_battle(f, file_name="", verbose=False): 145 | g = Digraph(graph_attr={"rankdir": "TB", "clusterrank": "local"}) 146 | prev_node = None 147 | node_idx = 0 148 | for turn_name, phase_dict in f.battle_history.items(): 149 | if turn_name == "init": 150 | pstr = 
prep_pet_str_obj(phase_dict) 151 | pstr[0] = ["Team 0: "] + pstr[0] 152 | pstr[1] = ["Team 1: "] + pstr[1] 153 | temp_table = html_table( 154 | header="Initial Teams", 155 | entries=pstr, 156 | table_attr=[("BORDER", "1")], 157 | header_font_attr=[("COLOR", "#FFFFFF")], 158 | header_bg_color="#1C6EA4", 159 | cell_font_attr=[], 160 | cell_border=0, 161 | column_align=[ 162 | "RIGHT", 163 | "CENTER", 164 | "CENTER", 165 | "CENTER", 166 | "CENTER", 167 | "CENTER", 168 | ], 169 | cell_bg_colors=[[]], 170 | ) 171 | g.node( 172 | str(node_idx), style="rounded,invisible", shape="box", label=temp_table 173 | ) 174 | prev_node = str(node_idx) 175 | node_idx += 1 176 | continue 177 | 178 | turn_name = turn_name[0].upper() + turn_name[1:] + " Turn" 179 | ### Should really do nested tables for start and each attack. That's next. 180 | for phase_name, phase_entry in phase_dict.items(): 181 | if "phase_move" in phase_name: 182 | if phase_name == "phase_move_start": 183 | header = f"{turn_name} Phase: Move-Team-Start" 184 | if not verbose: 185 | continue 186 | elif phase_name == "phase_move_end": 187 | header = f"{turn_name} Phase: Move-Team-End" 188 | pstr = prep_pet_str_obj(phase_entry) 189 | ### Only interested in final positions 190 | if len(pstr) > 0: 191 | pstr = pstr[-1] 192 | else: 193 | pstr = [[], []] 194 | pstr[0] = ["Team 0: "] + pstr[0] 195 | pstr[1] = ["Team 1: "] + pstr[1] 196 | entries = pstr 197 | 198 | elif "phase_start" == phase_name: 199 | header = f"{turn_name} Phase: Start Fight" 200 | entries = [] 201 | for iter_idx, temp_effect_info in enumerate(phase_entry): 202 | es, ec = prep_effect(temp_effect_info) 203 | if iter_idx == 0: 204 | entries.append(ec) 205 | entries.append(es) 206 | if len(entries) == 0: 207 | if not verbose: 208 | continue 209 | entries.append( 210 | [ 211 | "Effect", 212 | "Team", 213 | " Pet Index ", 214 | " Activating Pet ", 215 | "Targets", 216 | ] 217 | ) 218 | 219 | elif "phase_hurt_and_faint" in phase_name: 220 | header = 
f"{turn_name} Phase: Hurt and Faint" 221 | entries = [] 222 | if len(phase_entry) != 0: 223 | for iter_idx, temp_effect_info in enumerate(phase_entry): 224 | es, ec = prep_effect(temp_effect_info) 225 | if iter_idx == 0: 226 | entries.append(ec) 227 | entries.append(es) 228 | if len(entries) == 0: 229 | if not verbose: 230 | continue 231 | entries.append( 232 | [ 233 | "Effect", 234 | "Team", 235 | " Pet Index ", 236 | " Activating Pet ", 237 | "Targets", 238 | ] 239 | ) 240 | 241 | elif phase_name in [ 242 | "phase_attack_before", 243 | "phase_attack_after", 244 | "phase_attack", 245 | "phase_knockout", 246 | ]: 247 | if phase_name == "phase_attack_before": 248 | header = f"{turn_name} Phase: Before Attack" 249 | elif phase_name == "phase_attack": 250 | header = f"{turn_name} Phase: Attack" 251 | elif phase_name == "phase_attack_after": 252 | header = f"{turn_name} Phase: Attack After" 253 | elif phase_name == "phase_knockout": 254 | header = f"{turn_name} Phase: Knockout" 255 | entries = [] 256 | for temp_effect_info in phase_entry: 257 | es, ec = prep_effect(temp_effect_info) 258 | if len(entries) == 0: 259 | entries.append(ec) 260 | entries.append(es) 261 | if len(entries) == 0: 262 | if not verbose: 263 | continue 264 | entries.append( 265 | [ 266 | "Effect", 267 | "Team", 268 | " Pet Index ", 269 | " Activating Pet ", 270 | "Targets", 271 | ] 272 | ) 273 | 274 | else: 275 | continue 276 | 277 | temp_table = html_table( 278 | header=header, 279 | entries=entries, 280 | table_attr=[("BORDER", "1")], 281 | header_font_attr=[("COLOR", "#FFFFFF")], 282 | header_bg_color="#1C6EA4", 283 | cell_font_attr=[], 284 | cell_border=0, 285 | column_align=["CENTER", "CENTER", "CENTER", "CENTER", "CENTER"], 286 | cell_bg_colors=[[]], 287 | ) 288 | g.node( 289 | str(node_idx), style="rounded,invisible", shape="box", label=temp_table 290 | ) 291 | g.edge(prev_node, str(node_idx)) 292 | prev_node = str(node_idx) 293 | node_idx = node_idx + 1 294 | 295 | if len(file_name) > 0: 
296 | g.render(filename=file_name) 297 | return g 298 | 299 | 300 | # %% 301 | -------------------------------------------------------------------------------- /tests/test_lists.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | import numpy as np 4 | 5 | from sapai import * 6 | from sapai.compress import compress, decompress 7 | from sapai.lists import Slot, SAPList 8 | 9 | MIN = -5 10 | MAX = 20 11 | pet_names = list(data["pets"].keys()) 12 | 13 | 14 | class TestLists(unittest.TestCase): 15 | def test_empty_slot(self): 16 | s = Slot() 17 | self.assertIsNone(s.obj) 18 | self.assertTrue(s.empty) 19 | 20 | s = Slot(Slot()) 21 | self.assertIsNone(s.obj) 22 | self.assertTrue(s.empty) 23 | 24 | def test_pet_slot(self): 25 | s = Slot(Pet("fish")) 26 | s.obj = Pet("fish") 27 | self.assertEqual(s.obj.state, Pet("fish").state) 28 | self.assertFalse(s.empty) 29 | s = Slot(Slot(Pet("fish"))) 30 | self.assertEqual(s.obj.state, Pet("fish").state) 31 | self.assertFalse(s.empty) 32 | s.obj = Slot(Pet("fish")) 33 | self.assertEqual(s.obj.state, Pet("fish").state) 34 | self.assertFalse(s.empty) 35 | 36 | def test_state(self): 37 | s = Slot() 38 | self.assertEqual(s.state, {"type": "Slot"}) 39 | s = Slot(Pet("fish")) 40 | self.assertEqual(s.state, {"type": "Slot", "obj": Pet("fish").state}) 41 | 42 | def test_compress(self): 43 | s = Slot() 44 | compress(s) 45 | s = Slot(Pet("fish")) 46 | c = compress(s) 47 | test = decompress(c) 48 | self.assertEqual(s.state, test.state) 49 | 50 | def test_list_state(self): 51 | l = SAPList() 52 | self.assertEqual(l.state, SAPList.from_state(l.state).state) 53 | l = SAPList([Pet(pet_names[x]) for x in range(5)], nslots=5) 54 | self.assertEqual(l.state, SAPList.from_state(l.state).state) 55 | 56 | def test_internal_list_length(self): 57 | l = SAPList() 58 | self.assertEqual(len(l), 0) 59 | 60 | l = SAPList(nslots=5) 61 | self.assertEqual(len(l._slots), 5) 62 | 63 | for n in 
range(MIN, MAX): 64 | try: 65 | l = SAPList(nslots=n) 66 | if n < 1: 67 | raise Exception(f"SAPList should fail for {n}") 68 | except Exception as e: 69 | if n > 0: 70 | raise Exception(f"SAPList should work for {n}: {e}") 71 | continue 72 | 73 | self.assertEqual(len(l._slots), n) 74 | self.assertEqual(len([x for x in l]), n) 75 | 76 | def test_list_length(self): 77 | for i in range(0, MAX): 78 | items = [Pet("fish") for x in range(i)] 79 | l = SAPList(items) 80 | self.assertEqual(len(l), i) 81 | 82 | for n in range(1, MAX): 83 | l = SAPList(items) 84 | l.nslots = n 85 | if len(items) >= n: 86 | ### if there were more items than slots, then the behavior 87 | ### is to only have [:n] subset of slots left 88 | self.assertEqual(len(l.empty), 0) 89 | else: 90 | ### if there were more slots than items, then len should be 91 | ### original number of items 92 | self.assertEqual(len(l.empty), n - i) 93 | 94 | def test_setitem_list(self): 95 | for i in range(0, MAX): 96 | items = [Pet("fish") for x in range(i)] 97 | l = SAPList(items) 98 | for n in range(0, MAX): 99 | try: 100 | l[n] = Pet("fish") 101 | except Exception as e: 102 | if n < i: 103 | raise Exception(f"Indexing should work for i={i} & n={n}: {e}") 104 | 105 | def test_empty_idx(self): 106 | rand_idx = np.arange(0, MAX).astype(int) 107 | rand_test_size = np.random.randint(0, high=MAX, size=(MAX,)) 108 | rand_buffer_size = np.random.randint(0, high=MAX, size=(MAX,)) 109 | for i, rand_size in enumerate(rand_test_size): 110 | list_length = int(MAX + rand_buffer_size[i]) 111 | l = SAPList(nslots=list_length) 112 | fill_idx = np.random.choice(rand_idx, size=(rand_size,), replace=False) 113 | l[fill_idx] = [Pet("fish") for x in range(rand_size)] 114 | bit_set = np.ones(list_length) 115 | bit_set[fill_idx] = 0 116 | test_idx = list(np.where(bit_set == 1)[0]) 117 | self.assertEqual(l.empty, test_idx) 118 | 119 | def test_left_right(self): 120 | for n in range(0, MAX): 121 | l = SAPList(nslots=MAX) 122 | l[n] = 
Pet("fish") 123 | test_slot, test_idx = l.leftmost 124 | self.assertEqual(test_idx, n) 125 | test_slot, test_idx = l.rightmost 126 | self.assertEqual(test_idx, n) 127 | 128 | def test_move(self): 129 | for n in range(0, MAX): 130 | l = SAPList(nslots=MAX) 131 | l[0] = Pet("fish") 132 | try: 133 | l.move(0, n) 134 | except Exception as e: 135 | if n != 0: 136 | raise Exception(f"Move should not fail for n={n}: {e}") 137 | self.assertEqual(l[n].state, Slot(Pet("fish")).state) 138 | 139 | l = SAPList(nslots=MAX) 140 | l[n] = Pet("fish") 141 | l.move_forward(sidx=0, eidx=-1) 142 | self.assertEqual(l[0].state, Slot(Pet("fish")).state) 143 | 144 | l = SAPList(nslots=MAX) 145 | l[n] = Pet("fish") 146 | l.move_backward() 147 | self.assertEqual(l[-1].state, Slot(Pet("fish")).state) 148 | 149 | l = SAPList([Pet(pet_names[x]) for x in range(5)], nslots=5) 150 | ref = SAPList([Pet(pet_names[x]) for x in range(5)], nslots=5) 151 | l.move_forward() 152 | self.assertEqual(l.state, ref.state) 153 | 154 | l = SAPList(nslots=5) 155 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 156 | ref = SAPList(nslots=5) 157 | ref[0, 1, 2] = [Pet(pet_names[x]) for x in range(3)] 158 | l.move_forward() 159 | self.assertEqual(l.state, ref.state) 160 | 161 | l = SAPList(nslots=5) 162 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 163 | ref = SAPList(nslots=5) 164 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 165 | l.move_forward(0, 1) 166 | self.assertEqual(l.state, ref.state) 167 | 168 | l = SAPList(nslots=5) 169 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 170 | ref = SAPList(nslots=5) 171 | ref[0, 1, 4] = [Pet(pet_names[x]) for x in range(3)] 172 | l.move_forward(0, 2) 173 | self.assertEqual(l.state, ref.state) 174 | 175 | l = SAPList(nslots=5) 176 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 177 | ref = SAPList(nslots=5) 178 | ref[0, 1, 4] = [Pet(pet_names[x]) for x in range(3)] 179 | l.move_forward(0, 3) 180 | self.assertEqual(l.state, ref.state) 181 | 182 | 
l = SAPList(nslots=5) 183 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 184 | ref = SAPList(nslots=5) 185 | ref[0, 1, 4] = [Pet(pet_names[x]) for x in range(3)] 186 | l.move_forward(1, 3) 187 | self.assertEqual(l.state, ref.state) 188 | 189 | l = SAPList(nslots=5) 190 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 191 | ref = SAPList(nslots=5) 192 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 193 | l.move_forward(2, 3) 194 | self.assertEqual(l.state, ref.state) 195 | 196 | l = SAPList(nslots=5) 197 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 198 | ref = SAPList(nslots=5) 199 | ref[0, 2, 3] = [Pet(pet_names[x]) for x in range(3)] 200 | l.move_forward(2, 4) 201 | self.assertEqual(l.state, ref.state) 202 | 203 | l = SAPList(nslots=5) 204 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 205 | ref = SAPList(nslots=5) 206 | ref[2, 3, 4] = [Pet(pet_names[x]) for x in range(3)] 207 | l.move_backward() 208 | self.assertEqual(l.state, ref.state) 209 | 210 | l = SAPList(nslots=5) 211 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 212 | ref = SAPList(nslots=5) 213 | ref[1, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 214 | l.move_backward(0, 1) 215 | self.assertEqual(l.state, ref.state) 216 | 217 | l = SAPList(nslots=5) 218 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 219 | ref = SAPList(nslots=5) 220 | ref[1, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 221 | l.move_backward(0, 2) 222 | self.assertEqual(l.state, ref.state) 223 | 224 | l = SAPList(nslots=5) 225 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 226 | ref = SAPList(nslots=5) 227 | ref[2, 3, 4] = [Pet(pet_names[x]) for x in range(3)] 228 | l.move_backward(0, 3) 229 | self.assertEqual(l.state, ref.state) 230 | 231 | l = SAPList(nslots=5) 232 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 233 | ref = SAPList(nslots=5) 234 | ref[0, 3, 4] = [Pet(pet_names[x]) for x in range(3)] 235 | l.move_backward(1, 4) 236 | self.assertEqual(l.state, ref.state) 237 | 238 
| l = SAPList(nslots=5) 239 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 240 | ref = SAPList(nslots=5) 241 | ref[0, 3, 4] = [Pet(pet_names[x]) for x in range(3)] 242 | l.move_backward(2, 4) 243 | self.assertEqual(l.state, ref.state) 244 | 245 | l = SAPList(nslots=5) 246 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 247 | ref = SAPList(nslots=5) 248 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 249 | l.move_backward(3, 4) 250 | self.assertEqual(l.state, ref.state) 251 | 252 | def test_index(self): 253 | l = SAPList(nslots=5) 254 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 255 | for idx in l.filled: 256 | self.assertEqual(l.get_index(l[idx]), idx) 257 | self.assertEqual(l.get_index(l[idx].obj), idx) 258 | 259 | def test_remove(self): 260 | l = SAPList(nslots=5) 261 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 262 | ref = SAPList(nslots=5) 263 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 264 | ref[0] = Slot() 265 | l.remove(0) 266 | self.assertEqual(l.state, ref.state) 267 | 268 | l = SAPList(nslots=5) 269 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 270 | ref = SAPList(nslots=5) 271 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 272 | ref[2] = Slot() 273 | l.remove(l.slots[2]) 274 | self.assertEqual(l.state, ref.state) 275 | 276 | l = SAPList(nslots=5) 277 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 278 | ref = SAPList(nslots=5) 279 | ref[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 280 | ref[4] = Slot() 281 | l.remove(l.slots[4].obj) 282 | self.assertEqual(l.state, ref.state) 283 | 284 | def test_front_and_behind(self): 285 | l = SAPList(nslots=5) 286 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 287 | self.assertEqual(l.get_infront(l[0]), []) 288 | self.assertEqual(l.get_infront(l[1]), [l[0]]) 289 | self.assertEqual(l.get_infront(l[1], n=2), [l[0]]) 290 | self.assertEqual(l.get_infront(l[2], n=2), [l[1], l[0]]) 291 | self.assertEqual(l.get_infront(l[4], n=4), [l[3], l[2], l[1], 
l[0]]) 292 | 293 | self.assertEqual(l.get_behind(l[4]), []) 294 | self.assertEqual(l.get_behind(l[0]), [l[1]]) 295 | self.assertEqual(l.get_behind(l[0], 2), [l[1], l[2]]) 296 | self.assertEqual(l.get_behind(l[0], 4), [l[1], l[2], l[3], l[4]]) 297 | self.assertEqual(l.get_behind(l[1], 2), [l[2], l[3]]) 298 | self.assertEqual(l.get_behind(l[3], 2), [l[4]]) 299 | 300 | def test_append(self): 301 | l = SAPList(nslots=5) 302 | l[0, 2, 4] = [Pet(pet_names[x]) for x in range(3)] 303 | l.append(Pet(pet_names[3])) 304 | self.assertEqual(l[1].obj.state, Pet(pet_names[3]).state) 305 | 306 | l.append(Pet(pet_names[4])) 307 | self.assertEqual(l[3].obj.state, Pet(pet_names[4]).state) 308 | 309 | self.assertRaises(Exception, l.append, Pet(pet_names[5])) 310 | 311 | 312 | # %% 313 | 314 | # test = TestLists() 315 | # test.test_remove() 316 | 317 | # %% 318 | -------------------------------------------------------------------------------- /sapai/teams.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy 3 | 4 | from sapai.pets import Pet 5 | from sapai.lists import Slot, SAPList 6 | from numpy import int32, int64 7 | 8 | 9 | class Team(SAPList): 10 | """ 11 | Defines a team class. 12 | 13 | What should be included here that won't be included in just a list of 14 | animals? idk... 15 | 16 | Maybe including interaction between animals. For example, Tiger. Are there 17 | any other interactions? 
18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | slots=None, 24 | battle=False, 25 | shop=None, 26 | player=None, 27 | pack="StandardPack", 28 | seed_state=None, 29 | ): 30 | slots = slots or [] 31 | 32 | super().__init__(slots, 5, slot_class=TeamSlot) 33 | self._battle = battle 34 | self.seed_state = seed_state 35 | self.slots = [TeamSlot(seed_state=self.seed_state) for _ in range(self.nslots)] 36 | for iter_idx, obj in enumerate(slots): 37 | self[iter_idx] = obj 38 | self[iter_idx]._pet.team = self 39 | self.player = player 40 | self.shop = shop 41 | self.pack = "StandardPack" 42 | 43 | def move(self, sidx, tidx): 44 | """Moves animal from start idx to target idx""" 45 | target = self[tidx] 46 | if not target.empty: 47 | raise Exception("Attempted move to a populated position") 48 | ### Move 49 | self[tidx] = self[sidx] 50 | ### Dereference original position 51 | self[sidx] = TeamSlot(seed_state=self.seed_state) 52 | 53 | def move_forward(self, start_idx=0, end_idx=10): 54 | """ 55 | Adjust the location of the pets in the team, moving them to the furthest 56 | possible forward location using a recursive function. The arg idx may 57 | be provided to indicate the first index that is allowed to move 58 | forward. 
59 | 60 | """ 61 | empty_idx = [] 62 | filled_idx = [] 63 | for iter_idx, slot in enumerate(self): 64 | if slot.empty: 65 | empty_idx.append(iter_idx) 66 | else: 67 | filled_idx.append(iter_idx) 68 | if len(empty_idx) > 0: 69 | ### Only need to consider the first empty position 70 | empty_idx = empty_idx[0] 71 | 72 | ### Find first pet that can fill this empty position 73 | found = False 74 | for temp_idx in filled_idx: 75 | if temp_idx < start_idx: 76 | continue 77 | if temp_idx >= end_idx: 78 | continue 79 | if empty_idx < temp_idx: 80 | found = True 81 | ### Move pet 82 | self.move(temp_idx, empty_idx) 83 | break 84 | 85 | ### If a pet was moved, call recurisvely 86 | if found: 87 | self.move_forward(start_idx, end_idx) 88 | 89 | return 90 | 91 | def move_backward(self, **kwargs): 92 | """ 93 | Adjust the location of the pets in the team, moving them to the furthest 94 | possible backward location using a recursive function. 95 | 96 | This is useful for summoning purposes 97 | 98 | """ 99 | empty_idx = [] 100 | filled_idx = [] 101 | for iter_idx, slot in enumerate(self): 102 | if slot.empty: 103 | empty_idx.append(iter_idx) 104 | else: 105 | filled_idx.append(iter_idx) 106 | if len(empty_idx) > 0: 107 | ### Only need to consider the last empty position 108 | empty_idx = empty_idx[-1] 109 | 110 | ### Find first pet that can fill this empty position 111 | found = False 112 | for start_idx in filled_idx[::-1]: 113 | if empty_idx > start_idx: 114 | found = True 115 | ### Move pet 116 | self.move(start_idx, empty_idx) 117 | break 118 | 119 | ### If a pet was moved, call recurisvely 120 | if found: 121 | self.move_backward() 122 | 123 | return 124 | 125 | def remove(self, obj): 126 | if type(obj) == int: 127 | self.slots[obj] = TeamSlot(seed_state=self.seed_state) 128 | elif isinstance(obj, TeamSlot): 129 | found = False 130 | for iter_idx, temp_slot in enumerate(self.slots): 131 | if temp_slot == obj: 132 | found_idx = iter_idx 133 | found = True 134 | if not found: 
135 | raise Exception(f"Remove {obj} not found") 136 | self.slots[found_idx] = TeamSlot(seed_state=self.seed_state) 137 | elif isinstance(obj, Pet): 138 | found = False 139 | for iter_idx, temp_slot in enumerate(self.slots): 140 | temp_pet = temp_slot.pet 141 | if temp_pet == obj: 142 | found_idx = iter_idx 143 | found = True 144 | if not found: 145 | raise Exception(f"Remove {obj} not found") 146 | self.slots[found_idx] = TeamSlot(seed_state=self.seed_state) 147 | else: 148 | raise Exception(f"Object of type {type(obj)} not recognized") 149 | 150 | def check_friend(self, obj): 151 | if isinstance(obj, TeamSlot): 152 | found = False 153 | for iter_idx, temp_slot in enumerate(self.slots): 154 | if temp_slot == obj: 155 | found = True 156 | return found 157 | elif isinstance(obj, Pet): 158 | found = False 159 | for iter_idx, temp_slot in enumerate(self.slots): 160 | temp_pet = temp_slot.pet 161 | if temp_pet == obj: 162 | found = True 163 | return found 164 | else: 165 | raise Exception(f"Object of type {type(obj)} not recognized") 166 | 167 | def get_idx(self, obj): 168 | if isinstance(obj, TeamSlot): 169 | found = False 170 | for iter_idx, temp_slot in enumerate(self.slots): 171 | if temp_slot == obj: 172 | found_idx = iter_idx 173 | found = True 174 | if not found: 175 | raise Exception(f"get_idx {obj} not found") 176 | return found_idx 177 | elif isinstance(obj, Pet): 178 | found = False 179 | for iter_idx, temp_slot in enumerate(self.slots): 180 | temp_pet = temp_slot.pet 181 | if temp_pet == obj: 182 | found_idx = iter_idx 183 | found = True 184 | if not found: 185 | raise Exception(f"get_idx {obj} not found") 186 | return found_idx 187 | elif type(obj) == int: 188 | return obj 189 | elif isinstance(obj, int64): 190 | ### For numpy int 191 | return obj 192 | elif isinstance(obj, int32): 193 | ### For numpy int 194 | return obj 195 | else: 196 | raise Exception(f"Object of type {type(obj)} not recognized") 197 | 198 | def index(self, obj): 199 | return 
self.get_idx(obj) 200 | 201 | def get_fidx(self): 202 | ### Get possible indices for each team 203 | fidx = [] 204 | for iter_idx, temp_slot in enumerate(self): 205 | if not temp_slot.empty: 206 | ### Skiped if health is less than 0 207 | if temp_slot.pet.health > 0: 208 | fidx.append(iter_idx) 209 | return fidx 210 | 211 | def get_ahead(self, obj, n=1): 212 | pet_idx = self.get_idx(obj) 213 | fidx = [] 214 | for iter_idx, temp_slot in enumerate(self): 215 | if not temp_slot.empty: 216 | fidx.append(iter_idx) 217 | chosen_idx = [] 218 | for temp_idx in fidx: 219 | if temp_idx < pet_idx: 220 | chosen_idx.append(temp_idx) 221 | ret_pets = [] 222 | for temp_idx in chosen_idx[::-1]: 223 | ret_pets.append(self[temp_idx].pet) 224 | if len(ret_pets) >= n: 225 | break 226 | return ret_pets 227 | 228 | def get_behind(self, obj, n=1): 229 | pet_idx = self.get_idx(obj) 230 | fidx = [] 231 | for iter_idx, temp_slot in enumerate(self): 232 | if not temp_slot.empty: 233 | fidx.append(iter_idx) 234 | chosen = [] 235 | for temp_idx in fidx: 236 | if temp_idx > pet_idx: 237 | chosen.append(self.slots[temp_idx]) 238 | return chosen[0:n] 239 | 240 | def get_empty(self): 241 | empty_idx = [] 242 | for iter_idx, temp_slot in enumerate(self): 243 | if temp_slot.empty: 244 | empty_idx.append(iter_idx) 245 | return empty_idx 246 | 247 | def append(self, obj): 248 | obj = TeamSlot(obj, seed_state=self.seed_state) 249 | n = len(self) 250 | if n == len(self.slots): 251 | raise Exception("Attempted to append to a full team") 252 | empty_idx = self.get_empty() 253 | if len(empty_idx) == 0: 254 | raise Exception("This should not be possible") 255 | self.slots[empty_idx[0]] = obj 256 | 257 | def check_lvl3(self): 258 | for slot in self.slots: 259 | if slot.empty: 260 | continue 261 | if slot.pet.level == 3: 262 | return True 263 | return False 264 | 265 | @property 266 | def battle(self): 267 | return self._battle 268 | 269 | def __iter__(self): 270 | yield from self.slots 271 | 272 | def 
__len__(self): 273 | count = 0 274 | for temp_slot in self.slots: 275 | if not temp_slot.empty: 276 | count += 1 277 | return count 278 | 279 | def __getitem__(self, idx): 280 | return self.slots[idx] 281 | 282 | def __setitem__(self, idx, obj): 283 | if isinstance(obj, Pet): 284 | self.slots[idx] = TeamSlot(obj, seed_state=self.seed_state) 285 | elif isinstance(obj, TeamSlot): 286 | self.slots[idx] = obj 287 | elif type(obj) == str or type(obj) == numpy.str_: 288 | self.slots[idx] = TeamSlot(obj, seed_state=self.seed_state) 289 | else: 290 | raise Exception(f"Tried setting a team slot with type {type(obj).__name__}") 291 | 292 | def __repr__(self): 293 | repr_str = "" 294 | for iter_idx, slot in enumerate(self.slots): 295 | repr_str += f"{iter_idx}: {slot} \n " 296 | return repr_str 297 | 298 | def copy(self): 299 | return Team( 300 | [x.copy() for x in self], 301 | self.battle, 302 | self.player, 303 | seed_state=self.seed_state, 304 | ) 305 | 306 | @property 307 | def state(self): 308 | ### seed_state doesn't need to be stored for Team because the seed_state 309 | ### is stored by pets 310 | state_dict = { 311 | "type": "Team", 312 | "battle": self.battle, 313 | "team": [x.state for x in self.slots], 314 | "pack": self.pack, 315 | } 316 | return state_dict 317 | 318 | @classmethod 319 | def from_state(cls, state): 320 | team = [TeamSlot.from_state(x) for x in state["team"]] 321 | return cls( 322 | slots=team, 323 | battle=state["battle"], 324 | shop=None, 325 | player=None, 326 | pack=state["pack"], 327 | ) 328 | 329 | 330 | class TeamSlot(Slot): 331 | def __init__(self, obj=None, seed_state=None): 332 | super().__init__() 333 | self.seed_state = seed_state 334 | if isinstance(obj, Pet): 335 | self.obj = obj 336 | elif isinstance(obj, TeamSlot): 337 | self.obj = obj.pet 338 | elif obj is None: 339 | self.obj = Pet(seed_state=self.seed_state) 340 | elif type(obj) == str or type(obj) == numpy.str_: 341 | self.obj = Pet(obj, seed_state=self.seed_state) 342 | else: 
343 | raise Exception( 344 | f"Tried initalizing TeamSlot with type {type(obj).__name__}" 345 | ) 346 | 347 | @property 348 | def _pet(self): 349 | return self._obj 350 | 351 | @property 352 | def pet(self): 353 | return self.obj 354 | 355 | @property 356 | def empty(self): 357 | return self.obj.name == "pet-none" 358 | 359 | @property 360 | def attack(self): 361 | return self.obj.attack 362 | 363 | @property 364 | def health(self): 365 | return self.obj.health 366 | 367 | @property 368 | def ability(self): 369 | return self.obj.ability 370 | 371 | @property 372 | def level(self): 373 | return self.obj.level 374 | 375 | def __repr__(self): 376 | if self.obj.name == "pet-none": 377 | return "< Slot EMPTY >" 378 | else: 379 | pet_repr = str(self.obj) 380 | pet_repr = pet_repr[2:-2] 381 | return f"< Slot {pet_repr} >" 382 | 383 | def copy(self): 384 | return TeamSlot(self.obj.copy(), seed_state=self.seed_state) 385 | 386 | @property 387 | def state(self): 388 | ### seed_state doesn't need to be stored for TeamSlot because the 389 | ### seed_state is stored by pets 390 | state_dict = { 391 | "type": "TeamSlot", 392 | "pet": self.obj.state, 393 | } 394 | return state_dict 395 | 396 | @classmethod 397 | def from_state(cls, state): 398 | pet = Pet.from_state(state["pet"]) 399 | return cls(pet) 400 | 401 | 402 | # %% 403 | -------------------------------------------------------------------------------- /sapai/lists.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from collections.abc import Iterable 3 | import numpy as np 4 | 5 | 6 | class Slot: 7 | """ 8 | Implements basic behavior of a slot in a SAPList 9 | 10 | """ 11 | 12 | def __init__(self, obj=None): 13 | if isinstance(obj, Slot): 14 | obj = obj.obj 15 | self.obj = obj 16 | 17 | def __repr__(self): 18 | if self.obj is None: 19 | name = "EMPTY" 20 | elif self.obj.name == "pet-none" or self.obj.name == "food-none": 21 | name = "EMPTY" 22 | else: 
23 | name = str(self._obj)[2:-2] 24 | return f"< Slot {name} >" 25 | 26 | @property 27 | def obj(self): 28 | return self._obj 29 | 30 | @obj.setter 31 | def obj(self, obj): 32 | if obj is not None: 33 | if isinstance(obj, type(self)): 34 | obj = obj.obj 35 | if ( 36 | not hasattr(obj, "state") 37 | or not hasattr(obj, "from_state") 38 | or not hasattr(obj, "name") 39 | ): 40 | raise Exception( 41 | f"Input object {obj} must have a name, a state method, and a from_state method" 42 | ) 43 | self._obj = obj 44 | 45 | @property 46 | def empty(self): 47 | """ 48 | Returns if the given slot is empty 49 | """ 50 | return self._obj is None 51 | 52 | @obj.deleter 53 | def obj(self): 54 | del self.obj 55 | self._obj = None 56 | 57 | @property 58 | def state(self): 59 | ### seed_state doesn't need to be stored for TeamSlot because the 60 | ### seed_state is stored by pets 61 | state_dict = { 62 | "type": "Slot", 63 | } 64 | if not self.empty: 65 | state_dict["obj"] = self.obj.state 66 | return state_dict 67 | 68 | @classmethod 69 | def from_state(cls, state): 70 | from sapai.compress import state2obj 71 | 72 | if "obj" in state: 73 | obj = state2obj(state["obj"]) 74 | return cls(obj) 75 | 76 | 77 | class SAPList: 78 | """ 79 | Implements required methods and behaviors for lists of slots for SAP 80 | 81 | """ 82 | 83 | def __init__(self, slots=None, nslots=None, slot_class=Slot): 84 | slots = slots or [] 85 | 86 | self.slot_class = slot_class 87 | self._slots = [] 88 | self._nslots = None 89 | self.slots = slots 90 | if nslots is not None: 91 | self.nslots = nslots 92 | 93 | def __len__(self): 94 | """ 95 | Returns the number of filled slots 96 | """ 97 | return len(self.slots) 98 | 99 | def __iter__(self): 100 | yield from self.slots 101 | 102 | def __getitem__(self, idx): 103 | return self.slots[idx] 104 | 105 | def __setitem__(self, idx, obj): 106 | if isinstance(obj, Iterable): 107 | for i, temp_idx in enumerate(idx): 108 | self._slots[temp_idx] = Slot(obj[i]) 109 | elif 
isinstance(idx, (int, np.integer)): 110 | self._slots[idx] = Slot(obj) 111 | else: 112 | raise Exception(f"Index provided must be int, given {type(idx)}") 113 | 114 | def __repr__(self): 115 | repr_str = "" 116 | for iter_idx, slot in enumerate(self.slots): 117 | repr_str += f"{iter_idx}: {slot} \n " 118 | return repr_str 119 | 120 | @property 121 | def slots(self): 122 | return self._slots 123 | 124 | @slots.setter 125 | def slots(self, objs): 126 | if isinstance(self.slot_class, partial): 127 | test_type = self.slot_class.func 128 | else: 129 | test_type = self.slot_class 130 | if isinstance(objs, Iterable): 131 | temp_slots = [] 132 | for obj in objs: 133 | if not isinstance(obj, test_type): 134 | temp_slots.append(self.slot_class(obj)) 135 | else: 136 | temp_slots.append(obj) 137 | self._slots = temp_slots 138 | else: 139 | temp_slots = [] 140 | if not isinstance(objs, test_type): 141 | temp_slots.append(self.slot_class(objs)) 142 | else: 143 | temp_slots.append(objs) 144 | self._slots = temp_slots 145 | if self.nslots is not None: 146 | self.nslots = self._nslots 147 | 148 | @property 149 | def nslots(self): 150 | """ 151 | Number of slots in the slotlist 152 | """ 153 | return self._nslots 154 | 155 | @nslots.setter 156 | def nslots(self, length): 157 | """ 158 | Sets nslots and confirms the number of Slots in _slots 159 | """ 160 | self._nslots = length 161 | if self._nslots is None: 162 | return 163 | if not isinstance(length, (int, np.integer)): 164 | raise Exception(f"SAPList nslots must be int, given {type(length)}") 165 | if length < 0: 166 | raise Exception(f"SAPList nslots must be 0 or greater, given {length}") 167 | self._nslots = int(self._nslots) 168 | if len(self._slots) < self.nslots: 169 | [ 170 | self._slots.append(self.slot_class()) 171 | for _ in range(self.nslots - len(self._slots)) 172 | ] 173 | elif len(self._slots) > self.nslots: 174 | self._slots = self._slots[: self.nslots] 175 | 176 | @property 177 | def left(self): 178 | return 
self._slots[0] 179 | 180 | @property 181 | def right(self): 182 | return self._slots[-1] 183 | 184 | @property 185 | def leftmost(self): 186 | """ 187 | Returns leftmost slot that is not empty 188 | """ 189 | ret = None 190 | idx = None 191 | for i, slot in enumerate(self): 192 | if slot.empty: 193 | continue 194 | ret = slot 195 | idx = i 196 | break 197 | return ret, idx 198 | 199 | @property 200 | def rightmost(self): 201 | """ 202 | Returns rightmost slot that is not empty 203 | """ 204 | ret = None 205 | for i, slot in enumerate(self[::-1]): 206 | if slot.empty: 207 | continue 208 | ret = slot 209 | idx = len(self) - 1 - i 210 | break 211 | return ret, idx 212 | 213 | @property 214 | def empty(self): 215 | """ 216 | Return the indices of empty slots 217 | """ 218 | idx = [] 219 | for i, temp_slot in enumerate(self.slots): 220 | if temp_slot.empty: 221 | idx.append(i) 222 | return idx 223 | 224 | @property 225 | def filled(self): 226 | """ 227 | Return the indices of non-empty slots 228 | """ 229 | idx = [] 230 | for i, temp_slot in enumerate(self.slots): 231 | if not temp_slot.empty: 232 | idx.append(i) 233 | return idx 234 | 235 | def get_left(self, n=1): 236 | """ 237 | Return the n left-most slots 238 | """ 239 | return self._slots[:n] 240 | 241 | def get_right(self, n=1): 242 | """ 243 | Return the n right-most slots 244 | """ 245 | return self._slots[::-1][:n] 246 | 247 | def move(self, sidx, tidx): 248 | """ 249 | Move object from start idx to target idx 250 | """ 251 | target = self[tidx] 252 | if not target.empty: 253 | raise Exception("Attempted move to a populated position") 254 | self[tidx] = self[sidx] 255 | ### Dereference original position 256 | self[sidx] = self.slot_class() 257 | 258 | def move_right(self): 259 | """ 260 | Move all entries in SlotList after index i to the right 261 | """ 262 | self.move_backward() 263 | 264 | def move_left(self, sidx=0, eidx=-1): 265 | """ 266 | Move all entires in SlotList after index i to the left 267 | """ 
268 | self.move_forward(sidx, eidx) 269 | 270 | def move_forward(self, sidx=0, eidx=-1): 271 | """ 272 | Adjust the location of the pets in the team, moving them as far forward 273 | as sidx from eidx using a recursive function. 274 | 275 | """ 276 | if eidx == -1: 277 | eidx = len(self) 278 | if sidx >= eidx: 279 | raise Exception(f"End idx {sidx} must be greater than start idx {eidx}") 280 | 281 | empty_idx = [x for x in self.empty if x >= sidx and x <= eidx] 282 | filled_idx = [x for x in self.filled if x > sidx and x <= eidx] 283 | 284 | if len(empty_idx) > 0: 285 | ### Only need to consider the first empty position 286 | empty_idx = empty_idx[0] 287 | 288 | ### Find first pet that can fill this empty position 289 | found = False 290 | for temp_idx in filled_idx: 291 | if empty_idx < temp_idx: 292 | found = True 293 | ### Move pet 294 | self.move(temp_idx, empty_idx) 295 | break 296 | 297 | ### If a pet was moved, call recurisvely 298 | if found: 299 | self.move_forward(sidx, eidx) 300 | 301 | return 302 | 303 | def move_backward(self, sidx=0, eidx=-1): 304 | """ 305 | Adjust the location of the pets in the team, moving them to the furthest 306 | possible backward location using a recursive function. 
307 | 308 | """ 309 | if eidx == -1: 310 | eidx = len(self) 311 | if sidx >= eidx: 312 | raise Exception(f"End idx {sidx} must be greater than start idx {eidx}") 313 | 314 | empty_idx = [x for x in self.empty if x > sidx and x <= eidx] 315 | filled_idx = [x for x in self.filled if x >= sidx and x <= eidx] 316 | 317 | if len(empty_idx) > 0: 318 | ### Only need to consider the last empty position 319 | empty_idx = empty_idx[-1] 320 | 321 | ### Find first pet that can fill this empty position 322 | found = False 323 | for temp_idx in filled_idx[::-1]: 324 | if empty_idx > temp_idx: 325 | found = True 326 | ### Move pet 327 | self.move(temp_idx, empty_idx) 328 | break 329 | 330 | ### If a pet was moved, call recurisvely 331 | if found: 332 | self.move_backward(sidx, eidx) 333 | 334 | return 335 | 336 | def get_index(self, obj): 337 | """ 338 | Return the index of an input slot or item 339 | """ 340 | found_idx = None 341 | if isinstance(obj, Slot): 342 | for iter_idx, temp_slot in enumerate(self.slots): 343 | if temp_slot == obj: 344 | found_idx = iter_idx 345 | else: 346 | for iter_idx, temp_slot in enumerate(self.slots): 347 | if temp_slot.obj == obj: 348 | found_idx = iter_idx 349 | return found_idx 350 | 351 | def remove(self, obj): 352 | """ 353 | Remove by slot, item, or index from SAPList 354 | """ 355 | found = False 356 | if isinstance(obj, (int, np.integer)): 357 | self.slots[obj] = self.slot_class() 358 | found = True 359 | elif isinstance(obj, Slot): 360 | for iter_idx, temp_slot in enumerate(self.slots): 361 | if temp_slot == obj: 362 | self.slots[iter_idx] = self.slot_class() 363 | found = True 364 | break 365 | else: 366 | for iter_idx, temp_slot in enumerate(self.slots): 367 | if temp_slot.obj == obj: 368 | self.slots[iter_idx] = self.slot_class() 369 | found = True 370 | break 371 | 372 | if not found: 373 | raise Exception(f"Object {obj} not found in SAPList {self}") 374 | 375 | def get_behind(self, obj, n=1): 376 | """ 377 | Return the n slots behind 
the given object 378 | """ 379 | if isinstance(obj, (int, np.integer)): 380 | idx = obj 381 | else: 382 | idx = self.get_index(obj) 383 | 384 | return self.slots[idx + 1 : idx + n + 1] 385 | 386 | def get_infront(self, obj, n=1): 387 | """ 388 | Return the n slots infront of the given object 389 | """ 390 | if isinstance(obj, (int, np.integer)): 391 | idx = obj 392 | else: 393 | idx = self.get_index(obj) 394 | 395 | start_idx = max(idx - n, 0) 396 | end_idx = idx 397 | return self.slots[start_idx:end_idx][::-1] 398 | 399 | def append(self, obj): 400 | """ 401 | Adds object to first open slot 402 | """ 403 | if len(self.empty) > 0: 404 | self.slots[self.empty[0]] = Slot(obj) 405 | else: 406 | raise Exception("Attempted to append to full SAPList") 407 | 408 | @property 409 | def state(self): 410 | """ 411 | Return the state defining the current slotlist 412 | """ 413 | state_dict = { 414 | "type": "SAPList", 415 | "nslots": self.nslots, 416 | "slots": [x.state for x in self.slots], 417 | } 418 | return state_dict 419 | 420 | @classmethod 421 | def from_state(cls, state): 422 | """ 423 | Build a SlotList from the given state dictionary 424 | """ 425 | from sapai.compress import state2obj 426 | 427 | slots = [state2obj(x) for x in state["slots"]] 428 | return cls(slots, nslots=state["nslots"]) 429 | -------------------------------------------------------------------------------- /tests/test_battles.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | from sapai import * 7 | from sapai.battle import Battle 8 | from sapai.graph import graph_battle 9 | from sapai.compress import * 10 | 11 | 12 | class TestBattles(unittest.TestCase): 13 | def test_multi_hurt(self): 14 | t0 = Team(["badger", "camel", "fish"]) 15 | t0[0].pet._health = 1 16 | t0[0].pet._attack = 1 17 | t1 = Team(["cricket", "horse", "mosquito", "tiger"]) 18 | 19 | b = Battle(t0, t1) 20 | b.battle() 21 | 22 | def 
test_multi_faint(self): 23 | t0 = Team(["badger", "camel", "fish"]) 24 | t0[0].pet._health = 1 25 | t0[0].pet._attack = 5 26 | t1 = Team(["cricket", "horse", "mosquito", "tiger"]) 27 | 28 | b = Battle(t0, t1) 29 | b.battle() 30 | 31 | def test_before_and_after_attack(self): 32 | t0 = Team(["elephant", "snake", "dragon", "fish"]) 33 | t1 = Team(["cricket", "horse", "fly", "tiger"]) 34 | t0[2]._health = 50 35 | 36 | b = Battle(t0, t1) 37 | b.battle() 38 | t0 = Team(["elephant", "snake", "dragon", "fish"]) 39 | t1 = Team(["cricket", "horse", "fly", "tiger"]) 40 | t0[2]._health = 50 41 | 42 | b = Battle(t0, t1) 43 | b.battle() 44 | 45 | def test_rhino_test(self): 46 | t0 = Team(["horse", "horse", "horse", "horse"]) 47 | t1 = Team(["rhino", "tiger"]) 48 | 49 | b = Battle(t0, t1) 50 | b.battle() 51 | 52 | def test_hippo_test(self): 53 | t0 = Team(["horse", "horse", "horse", "horse"]) 54 | t1 = Team(["hippo", "tiger"]) 55 | 56 | b = Battle(t0, t1) 57 | b.battle() 58 | 59 | def test_whale_without_swallow_target(self): 60 | team1 = Team([Pet("fish")]) 61 | team2 = Team([Pet("whale"), Pet("hedgehog"), Pet("fish"), Pet("rabbit")]) 62 | 63 | test_battle = Battle(team1, team2) 64 | test_battle.battle() 65 | 66 | def test_cat_battle(self): 67 | team1 = Team([Pet("fish")]) 68 | team2 = Team([Pet("cat")]) 69 | 70 | test_battle = Battle(team1, team2) 71 | test_battle.battle() 72 | 73 | def test_multiple_hedgehog(self): 74 | team1 = Team([Pet("fish"), Pet("hedgehog")]) 75 | team2 = Team([Pet("elephant"), Pet("hedgehog")]) 76 | 77 | test_battle = Battle(team1, team2) 78 | test_battle.battle() 79 | 80 | def test_whale_parrot_swallow(self): 81 | team1 = Team([Pet("whale"), Pet("parrot")]) 82 | team2 = Team([Pet("fish"), "dragon"]) 83 | 84 | player1 = Player(team=team1) 85 | player2 = Player(team=team2) 86 | 87 | player1.end_turn() 88 | player2.end_turn() 89 | 90 | test_battle = Battle(player1.team, player2.team) 91 | test_battle.battle() 92 | 93 | def 
test_caterpillar_order_high_attack(self): 94 | cp = Pet("caterpillar") 95 | cp.level = 3 96 | cp._attack = 5 # 1 more than dolphin 97 | cp._health = 7 98 | t = Team([cp, "dragon"]) 99 | t2 = Team(["dolphin", "dragon"]) 100 | b = Battle(t, t2) 101 | r = b.battle() 102 | # print(b.battle_history) # caterpillar evolves first, dolphin snipes butterfly, 1v2 loss 103 | self.assertEqual(r, 1) 104 | 105 | def test_caterpillar_order_low_attack(self): 106 | cp = Pet("caterpillar") 107 | cp.level = 3 108 | cp._attack = 1 109 | cp._health = 7 110 | t = Team([cp, "dragon"]) 111 | t2 = Team(["dolphin", "dragon"]) 112 | b = Battle(t, t2) 113 | r = b.battle() 114 | # print(b.battle_history) # dolphin hits caterpillar, caterpillar evolves, copies dragon, win 115 | self.assertEqual(r, 0) 116 | 117 | def test_dodo(self): 118 | dodo = Pet("dodo") 119 | dodo.level = 3 120 | dodo._attack = 10 121 | team1 = Team([Pet("leopard"), dodo]) 122 | 123 | fish = Pet("fish") 124 | fish._attack = 5 125 | fish._health = 20 126 | team2 = Team([fish]) 127 | 128 | test_battle = Battle(team1, team2) 129 | result = test_battle.battle() 130 | 131 | # dodo adds enough attack for leopard to kill fish 132 | self.assertEqual(result, 0) 133 | 134 | def test_ant_in_battle(self): 135 | team1 = Team([Pet("ant"), Pet("fish")]) 136 | team2 = Team([Pet("camel")]) 137 | 138 | test_battle = Battle(team1, team2) 139 | result = test_battle.battle() 140 | self.assertEqual(result, 0) 141 | 142 | def test_horse_in_battle(self): 143 | team1 = Team([Pet("cricket"), Pet("horse")]) 144 | team2 = Team([Pet("camel")]) 145 | 146 | test_battle = Battle(team1, team2) 147 | result = test_battle.battle() 148 | self.assertEqual(test_battle.t0.empty, [0, 1, 2, 3, 4]) 149 | self.assertEqual(test_battle.t1[0].health, 1) 150 | 151 | def test_horse_with_bee_in_battle(self): 152 | cricket = Pet("cricket") 153 | cricket.status = "status-honey-bee" 154 | team1 = Team([cricket, Pet("horse")]) 155 | fish = Pet("fish") 156 | fish._health = 5 
157 | team2 = Team([fish, Pet("beaver")]) 158 | 159 | test_battle = Battle(team1, team2) 160 | result = test_battle.battle() 161 | self.assertEqual(result, 2) 162 | 163 | def test_mosquito_in_battle(self): 164 | team1 = Team([Pet("mosquito")]) 165 | team2 = Team([Pet("pig")]) 166 | 167 | test_battle = Battle(team1, team2) 168 | result = test_battle.battle() 169 | self.assertEqual(result, 0) 170 | 171 | def test_blowfish_pingpong(self): 172 | # they hit eachother once, rest of the battle is constant hurt triggers until they both faint 173 | b1 = Pet("blowfish") 174 | b1._attack = 1 175 | b1._health = 50 176 | 177 | b2 = Pet("blowfish") 178 | b2._attack = 1 179 | b2._health = 50 180 | 181 | b = Battle(Team([b1]), Team([b2])) 182 | r = b.battle() 183 | self.assertTrue( 184 | "attack 1" not in b.battle_history 185 | ) # they attack eachother, then keep using hurt_triggers until one of them dies, should never reach a 2nd attack phase 186 | 187 | def test_elephant_blowfish(self): 188 | # blowfish snipes first fish in 'before-attack' phase of elephant, leaving elephant without a target to attack normally 189 | # then snipes second fish in next turn's 'before attack' 190 | state = np.random.RandomState(seed=1).get_state() 191 | 192 | e1 = Pet("elephant") 193 | e1._attack = 1 194 | e1._health = 5 195 | 196 | b1 = Pet("blowfish", seed_state=state) 197 | b1._attack = 1 198 | b1._health = 5 199 | 200 | f1 = Pet("fish") 201 | f1._attack = 50 202 | f1._health = 1 203 | f1.status = "status-splash-attack" 204 | 205 | f2 = Pet("fish") 206 | f2._attack = 50 207 | f2._health = 1 208 | f2.status = "status-splash-attack" 209 | 210 | b = Battle(Team([e1, b1]), Team([f1, f2])) 211 | r = b.battle() 212 | self.assertEqual(r, 0) 213 | 214 | def test_hedgehog_blowfish_camel_hurt_team(self): 215 | # standard hedgehog blowfish camel teams facing off against eachother 216 | # lots of hurt triggers going off within one turn 217 | state1 = np.random.RandomState(seed=2).get_state() 218 | state2 = 
np.random.RandomState(seed=2).get_state() 219 | 220 | bf1 = Pet("blowfish", seed_state=state1) 221 | bf1._attack = 20 222 | bf1._health = 20 223 | bf1.level = 3 224 | bf1.status = "status-garlic-armor" 225 | 226 | c1 = Pet("camel") 227 | c1._attack = 20 228 | c1._health = 20 229 | c1.level = 2 230 | c1.status = "status-garlic-armor" 231 | 232 | hh1 = Pet("hedgehog") 233 | hh2 = Pet("hedgehog") 234 | 235 | bf2 = Pet("blowfish", seed_state=state2) 236 | bf2._attack = 20 237 | bf2._health = 20 238 | bf2.level = 3 239 | bf2.status = "status-garlic-armor" 240 | 241 | c2 = Pet("camel") 242 | c2._attack = 20 243 | c2._health = 20 244 | c2.level = 2 245 | c2.status = "status-garlic-armor" 246 | 247 | hh3 = Pet("hedgehog") 248 | hh4 = Pet("hedgehog") 249 | 250 | b = Battle(Team([hh1, hh2, c1, bf1]), Team([hh3, hh4, c2, bf2])) 251 | r = b.battle() 252 | self.assertEqual(r, 2) 253 | 254 | def test_hedgehog_vs_honey(self): 255 | hh1 = Pet("hedgehog") 256 | hh2 = Pet("hedgehog") 257 | hh3 = Pet("hedgehog") 258 | hh4 = Pet("hedgehog") 259 | hh5 = Pet("hedgehog") 260 | f1 = Pet("fish") 261 | f1.status = "status-honey-bee" 262 | 263 | b = Battle(Team([hh1, hh2, hh3, hh4, hh5]), Team([f1])) 264 | r = b.battle() 265 | # ability triggers always go before status triggers 266 | # fish wins since honey bee spawns at the end of the turn, after all faint triggers are completed 267 | self.assertEqual(r, 1) 268 | 269 | def test_hedgehog_vs_mushroom(self): 270 | hh1 = Pet("hedgehog") 271 | hh2 = Pet("hedgehog") 272 | hh3 = Pet("hedgehog") 273 | hh4 = Pet("hedgehog") 274 | hh5 = Pet("hedgehog") 275 | f1 = Pet("fish") 276 | f1.status = "status-extra-life" 277 | 278 | b = Battle(Team([hh1, hh2, hh3, hh4, hh5]), Team([f1])) 279 | r = b.battle() 280 | # ability triggers always go before status triggers 281 | # fish wins since mushroom spawns at the end of the turn, after all faint triggers are completed 282 | self.assertEqual(r, 1) 283 | 284 | def test_mushroom_scorpion(self): 285 | scorpion = 
Pet("scorpion") 286 | scorpion.status = "status-extra-life" 287 | b = Battle(Team([scorpion]), Team(["dragon"])) 288 | r = b.battle() 289 | self.assertEqual(r, 2) # draw since scorpion respawns with poison. 290 | 291 | def test_badger_draws(self): 292 | # normal 1v1 293 | b1 = Pet("badger") 294 | b2 = Pet("badger") 295 | b = Battle(Team([b1]), Team([b2])) 296 | r = b.battle() 297 | self.assertEqual(r, 2) 298 | 299 | # 1 survives, enemy ability kills 300 | b1 = Pet("badger") 301 | b1._health = 6 302 | b2 = Pet("badger") 303 | b = Battle(Team([b1]), Team([b2])) 304 | r = b.battle() 305 | self.assertEqual(r, 2) 306 | 307 | # normal honey 1v1 308 | hb1 = Pet("badger") 309 | hb1.status = "status-honey-bee" 310 | hb2 = Pet("badger") 311 | hb2.status = "status-honey-bee" 312 | b = Battle(Team([hb1]), Team([hb2])) 313 | r = b.battle() 314 | self.assertEqual(r, 2) 315 | 316 | # 1 survives, enemy ability kills 317 | hb1 = Pet("badger") 318 | hb1.status = "status-honey-bee" 319 | hb1._health = 6 320 | hb2 = Pet("badger") 321 | hb2.status = "status-honey-bee" 322 | b = Battle(Team([hb1]), Team([hb2])) 323 | r = b.battle() 324 | self.assertEqual(r, 2) 325 | 326 | # 1 survives, enemey ability kills, even with less attack priority should NOT be able to hit bee with ability 327 | hb1 = Pet("badger") 328 | hb1.status = "status-honey-bee" 329 | hb1._attack = 4 330 | hb1._health = 6 331 | hb2 = Pet("badger") 332 | hb2.status = "status-honey-bee" 333 | b = Battle(Team([hb1]), Team([hb2])) 334 | r = b.battle() 335 | self.assertEqual(r, 2) 336 | 337 | # badger with less attack can kill zombie-cricket 338 | b1 = Pet("badger") 339 | c1 = Pet("cricket") 340 | c1._attack = 6 341 | b = Battle(Team([b1]), Team([c1])) 342 | r = b.battle() 343 | self.assertEqual(r, 2) 344 | 345 | # badger with higher priority hits nothing with ability, zombie-cricket spanws and bee spawns 346 | hb1 = Pet("badger") 347 | hb1.status = "status-honey-bee" 348 | c1 = Pet("cricket") 349 | c1._attack = 4 350 | b = 
Battle(Team([hb1]), Team([c1])) 351 | r = b.battle() 352 | self.assertEqual(r, 2) 353 | 354 | def test_badger_wins(self): 355 | # bee win 356 | hb1 = Pet("badger") 357 | hb1.status = "status-honey-bee" 358 | b2 = Pet("badger") 359 | b = Battle(Team([hb1]), Team([b2])) 360 | r = b.battle() 361 | self.assertEqual(r, 0) 362 | 363 | # badger with higher priority hits nothing with ability, zombie-cricket spanws and wins 364 | c1 = Pet("cricket") 365 | c1._attack = 4 366 | b1 = Pet("badger") 367 | b = Battle(Team([c1]), Team([b1])) 368 | r = b.battle() 369 | self.assertEqual(r, 0) 370 | 371 | # badger with less attack can kill zombie-cricket, then bee spawns 372 | hb1 = Pet("badger") 373 | hb1.status = "status-honey-bee" 374 | c1 = Pet("cricket") 375 | c1._attack = 6 376 | b = Battle(Team([hb1]), Team([c1])) 377 | r = b.battle() 378 | self.assertEqual(r, 0) 379 | 380 | def test_rat_summons_at_front(self): 381 | team1 = Team(["rat", "blowfish"]) 382 | fish = Pet("fish") 383 | fish._attack = 5 384 | big_attack_pet = Pet("beaver") 385 | big_attack_pet._attack = 50 386 | team2 = Team([fish, big_attack_pet]) 387 | 388 | test_battle = Battle(team1, team2) 389 | result = test_battle.battle() 390 | self.assertEqual(result, 0) 391 | 392 | def test_peacock(self): 393 | ### Check that peacock attack is correct after battle 394 | 395 | ### Check peacock attack after elephant for all three levels 396 | 397 | ### Check peacock after headgehog on both teams 398 | 399 | ### Implement later with others 400 | pass 401 | 402 | 403 | # %% 404 | -------------------------------------------------------------------------------- /sapai/player.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sapai import data 3 | from sapai.battle import Battle 4 | from sapai.effects import ( 5 | RespawnPet, 6 | SummonPet, 7 | SummonRandomPet, 8 | get_effect_function, 9 | get_target, 10 | ) 11 | 12 | import sapai.shop 13 | from sapai.shop import 
Shop, ShopSlot 14 | from sapai.teams import Team, TeamSlot 15 | 16 | 17 | def storeaction(func): 18 | def store_action(*args, **kwargs): 19 | player = args[0] 20 | action_name = str(func.__name__).split(".")[-1] 21 | targets = func(*args, **kwargs) 22 | store_targets = [] 23 | if targets is not None: 24 | for entry in targets: 25 | if getattr(entry, "state", False): 26 | store_targets.append(entry.state) 27 | player.action_history.append((action_name, store_targets)) 28 | 29 | ### Make sure that the func returned as the same name as input func 30 | store_action.__name__ = func.__name__ 31 | 32 | return store_action 33 | 34 | 35 | class Player: 36 | """ 37 | Defines and implements all of the actions that a player can take. In particular 38 | each of these actions is directly tied to the actions reinforment learning 39 | models can take. 40 | 41 | Actions with the shop are based off of Objects and not based off of indices. 42 | There is a huge advantage to doing things this way. The index that a Pet/Food 43 | is in a shop is arbitrary. Therefore, when actions are based off the Object, 44 | The ML Agent will not have to learn the index invariances of the shop. 45 | 46 | The Player class is allowed to make appropriate changes to the Shop and 47 | Team. Therefore, Shops and Teams input into the Player class will not be 48 | static. The Player class is also responsible for checking all of the 49 | relevant Pet triggers when taking any action. 
50 | 51 | """ 52 | 53 | def __init__( 54 | self, 55 | shop=None, 56 | team=None, 57 | lives=10, 58 | default_gold=10, 59 | gold=10, 60 | turn=1, 61 | lf_winner=None, 62 | action_history=None, 63 | pack="StandardPack", 64 | seed_state=None, 65 | wins=0, 66 | ): 67 | action_history = action_history or [] 68 | 69 | self.shop = shop 70 | self.team = team 71 | self.lives = lives 72 | self.default_gold = default_gold 73 | self.gold = gold 74 | self.pack = pack 75 | self.turn = turn 76 | self.wins = wins 77 | 78 | ### Default Parameters 79 | self._max_team = 5 80 | 81 | ### Keep track of outcome of last battle for snail 82 | self.lf_winner = lf_winner 83 | 84 | ### Initialize shop and team if not provided 85 | if self.shop is None: 86 | self.shop = Shop(pack=self.pack, seed_state=seed_state) 87 | if self.team is None: 88 | self.team = Team(seed_state=seed_state) 89 | 90 | if type(self.shop) == list: 91 | self.shop = Shop(self.shop, seed_state=seed_state) 92 | if type(self.team) == list: 93 | self.team = Team(self.team, seed_state=seed_state) 94 | 95 | ### Connect objects 96 | self.team.player = self 97 | for slot in self.team: 98 | slot._pet.player = self 99 | slot._pet.shop = self.shop 100 | 101 | for slot in self.shop: 102 | slot.obj.player = self 103 | slot.obj.shop = self.shop 104 | 105 | ### This stores the history of actions taken by the given player 106 | if len(action_history) == 0: 107 | self.action_history = [] 108 | else: 109 | self.action_history = list(action_history) 110 | 111 | @storeaction 112 | def start_turn(self, winner=None): 113 | ### Update turn count and gold 114 | self.turn += 1 115 | self.gold = self.default_gold 116 | self.lf_winner = winner 117 | 118 | ### Roll shop 119 | self.shop.turn += 1 120 | self.shop.roll() 121 | 122 | ### Activate start-of-turn triggers after rolling shop 123 | for slot in self.team: 124 | slot._pet.sot_trigger() 125 | 126 | return () 127 | 128 | @storeaction 129 | def buy_pet(self, pet): 130 | """Buy one pet from the 
shop""" 131 | if len(self.team) == self._max_team: 132 | raise Exception("Attempted to buy Pet on full team") 133 | 134 | if type(pet) == int: 135 | pet = self.shop[pet] 136 | 137 | if isinstance(pet, ShopSlot): 138 | pet = pet.obj 139 | 140 | if type(pet).__name__ != "Pet": 141 | raise Exception(f"Attempted to buy_pet using object {pet}") 142 | 143 | shop_idx = self.shop.index(pet) 144 | shop_slot = self.shop.slots[shop_idx] 145 | cost = shop_slot.cost 146 | 147 | if cost > self.gold: 148 | raise Exception( 149 | f"Attempted to buy Pet of cost {cost} with only {self.gold} gold" 150 | ) 151 | 152 | ### Connect pet with current Player 153 | pet.team = self.team 154 | pet.player = self 155 | pet.shop = self.shop 156 | 157 | ### Make all updates 158 | self.gold -= cost 159 | self.team.append(pet) 160 | self.shop.buy(pet) 161 | 162 | ### Check buy_friend triggers after purchase 163 | for slot in self.team: 164 | slot._pet.buy_friend_trigger(pet) 165 | 166 | ### Check summon triggers after purchse 167 | for slot in self.team: 168 | slot._pet.friend_summoned_trigger(pet) 169 | 170 | return (pet,) 171 | 172 | @storeaction 173 | def buy_food(self, food, team_pet=None): 174 | """ 175 | Buy and feed one food from the shop 176 | 177 | team_pet is either the purchase target or empty for food effect target 178 | 179 | """ 180 | if type(food) == int: 181 | food = self.shop[food] 182 | if food.slot_type != "food": 183 | raise Exception("Shop slot not food") 184 | if isinstance(food, ShopSlot): 185 | food = food.obj 186 | if type(food).__name__ != "Food": 187 | raise Exception(f"Attempted to buy_food using object {food}") 188 | 189 | if team_pet is None: 190 | targets, _ = get_target(food, [0, None], [self.team]) 191 | else: 192 | if type(team_pet) == int: 193 | team_pet = self.team[team_pet] 194 | if isinstance(team_pet, TeamSlot): 195 | team_pet = team_pet._pet 196 | if not self.team.check_friend(team_pet): 197 | raise Exception(f"Attempted to buy food for Pet not on team 
{team_pet}") 198 | if type(team_pet).__name__ != "Pet": 199 | raise Exception(f"Attempted to buy_pet using object {team_pet}") 200 | targets = [team_pet] 201 | 202 | shop_idx = self.shop.index(food) 203 | shop_slot = self.shop.slots[shop_idx] 204 | cost = shop_slot.cost 205 | 206 | if cost > self.gold: 207 | raise Exception( 208 | f"Attempted to buy Pet of cost {cost} with only {self.gold} gold" 209 | ) 210 | 211 | ### Before feeding, check for cat 212 | for slot in self.team: 213 | if slot._pet.name != "pet-cat": 214 | continue 215 | slot._pet.cat_trigger(food) 216 | 217 | ### Make all updates 218 | self.gold -= cost 219 | self.shop.buy(food) 220 | for pet in targets: 221 | levelup = pet.eat(food) 222 | ### Check for levelup triggers if appropriate 223 | if levelup: 224 | pet.levelup_trigger(pet) 225 | self.shop.levelup() 226 | 227 | ### After feeding, check for eats_shop_food triggers 228 | for slot in self.team: 229 | slot._pet.eats_shop_food_trigger(pet) 230 | 231 | ### After feeding, check for buy_food triggers 232 | for slot in self.team: 233 | slot._pet.buy_food_trigger() 234 | 235 | ### Check if any animals fainted because of pill and if any other 236 | ### animals fainted because of those animals fainting 237 | pp = Battle.update_pet_priority(self.team, Team()) # no enemy team in shop 238 | status_list = [] 239 | while True: 240 | ### Get a list of fainted pets 241 | fainted_list = [] 242 | for _, pet_idx in pp: 243 | p = self.team[pet_idx].pet 244 | if p.name == "pet-none": 245 | continue 246 | if p.health <= 0: 247 | fainted_list.append(pet_idx) 248 | if p.status != "none": 249 | status_list.append([p, pet_idx]) 250 | 251 | ### check every fainted pet 252 | faint_targets_list = [] 253 | for pet_idx in fainted_list: 254 | fainted_pet = self.team[pet_idx].pet 255 | ### check for all pets that trigger off this fainted pet (including self) 256 | for _, te_pet_idx in pp: 257 | other_pet = self.team[te_pet_idx].pet 258 | te_idx = [0, pet_idx] 259 | activated, 
targets, possible = other_pet.faint_trigger( 260 | fainted_pet, te_idx 261 | ) 262 | if activated: 263 | faint_targets_list.append( 264 | [fainted_pet, pet_idx, activated, targets, possible] 265 | ) 266 | 267 | ### If no trigger was activated, then the pet was never removed. 268 | ### Check to see if it should be removed now. 269 | if self.team.check_friend(fainted_pet): 270 | self.team.remove(fainted_pet) 271 | 272 | ### If pet was summoned, then need to check for summon triggers 273 | for ( 274 | fainted_pet, 275 | pet_idx, 276 | activated, 277 | targets, 278 | possible, 279 | ) in faint_targets_list: 280 | self.check_summon_triggers( 281 | fainted_pet, pet_idx, activated, targets, possible 282 | ) 283 | 284 | ### if pet was hurt, then need to check for hurt triggers 285 | hurt_list = [] 286 | for _, pet_idx in pp: 287 | p = self.team[pet_idx].pet 288 | while p._hurt > 0: 289 | hurt_list.append(pet_idx) 290 | activated, targets, possible = p.hurt_trigger(Team()) 291 | 292 | pp = Battle.update_pet_priority(self.team, Team()) 293 | 294 | ### if nothing happend, stop the loop 295 | if len(fainted_list) == 0 and len(hurt_list) == 0: 296 | break 297 | 298 | ### Check for status triggers on pet 299 | for p, pet_idx in status_list: 300 | self.check_status_triggers(p, pet_idx) 301 | 302 | return (food, targets) 303 | 304 | def check_summon_triggers(self, fainted_pet, pet_idx, activated, targets, possible): 305 | if not activated: 306 | return 307 | func = get_effect_function(fainted_pet) 308 | if func not in [RespawnPet, SummonPet, SummonRandomPet]: 309 | return 310 | for temp_te in targets: 311 | for temp_slot in self.team: 312 | temp_pet = temp_slot.pet 313 | temp_pet.friend_summoned_trigger(temp_te) 314 | 315 | def check_status_triggers(self, fainted_pet, pet_idx): 316 | if fainted_pet.status not in ["status-honey-bee", "status-extra-life"]: 317 | return 318 | 319 | ability = data["statuses"][fainted_pet.status]["ability"] 320 | fainted_pet.set_ability(ability) 321 | 
te_idx = [0, pet_idx] 322 | activated, targets, possible = fainted_pet.faint_trigger(fainted_pet, te_idx) 323 | self.check_summon_triggers(fainted_pet, pet_idx, activated, targets, possible) 324 | 325 | @storeaction 326 | def sell(self, pet): 327 | """Sell one pet on the team""" 328 | if type(pet) == int: 329 | pet = self.team[pet] 330 | 331 | if isinstance(pet, TeamSlot): 332 | pet = pet._pet 333 | 334 | if type(pet).__name__ != "Pet": 335 | raise Exception(f"Attempted to sell Object {pet}") 336 | 337 | ### Activate sell trigger first 338 | for slot in self.team: 339 | slot._pet.sell_trigger(pet) 340 | 341 | if self.team.check_friend(pet): 342 | self.team.remove(pet) 343 | 344 | ### Add default gold 345 | self.gold += 1 346 | 347 | return (pet,) 348 | 349 | @storeaction 350 | def sell_buy(self, team_pet, shop_pet): 351 | """Sell one team pet and replace it with one shop pet""" 352 | if type(shop_pet) == int: 353 | shop_pet = self.shop[shop_pet] 354 | if type(team_pet) == int: 355 | team_pet = self.team[team_pet] 356 | 357 | if isinstance(shop_pet, ShopSlot): 358 | shop_pet = shop_pet.obj 359 | if isinstance(team_pet, TeamSlot): 360 | team_pet = team_pet._pet 361 | 362 | if type(shop_pet).__name__ != "Pet": 363 | raise Exception(f"Attempted sell_buy with Shop item {shop_pet}") 364 | if type(team_pet).__name__ != "Pet": 365 | raise Exception(f"Attempted sell_buy with Team Pet {team_pet}") 366 | 367 | ### Activate sell trigger first 368 | self.sell(team_pet) 369 | 370 | ### Then attempt to buy shop pet 371 | self.buy_pet(shop_pet) 372 | 373 | return (team_pet, shop_pet) 374 | 375 | def freeze(self, obj): 376 | """Freeze one pet or food in the shop""" 377 | if isinstance(obj, ShopSlot): 378 | obj = obj.obj 379 | shop_idx = self.shop.index(obj) 380 | elif type(obj) == int: 381 | shop_idx = obj 382 | shop_slot = self.shop.slots[shop_idx] 383 | shop_slot.freeze() 384 | return (shop_slot,) 385 | 386 | def unfreeze(self, obj): 387 | """Unfreeze one pet or food in the 
shop""" 388 | if isinstance(obj, ShopSlot): 389 | obj = obj.obj 390 | shop_idx = self.shop.index(obj) 391 | elif type(obj) == int: 392 | shop_idx = obj 393 | shop_slot = self.shop.slots[shop_idx] 394 | shop_slot.unfreeze() 395 | return (shop_slot,) 396 | 397 | @storeaction 398 | def roll(self): 399 | """Roll shop""" 400 | if self.gold < 1: 401 | raise Exception("Attempt to roll without gold") 402 | self.shop.roll() 403 | self.gold -= 1 404 | return () 405 | 406 | @staticmethod 407 | def combine_pet_stats(pet_to_keep, pet_to_be_merged): 408 | """Pet 1 is the pet that is kept""" 409 | c_attack = max(pet_to_keep._attack, pet_to_be_merged._attack) + 1 410 | c_until_end_of_battle_attack = max( 411 | pet_to_keep._until_end_of_battle_attack_buff, 412 | pet_to_be_merged._until_end_of_battle_attack_buff, 413 | ) 414 | c_health = max(pet_to_keep._health, pet_to_be_merged._health) + 1 415 | c_until_end_of_battle_health = max( 416 | pet_to_keep._until_end_of_battle_health_buff, 417 | pet_to_be_merged._until_end_of_battle_health_buff, 418 | ) 419 | cstatus = get_combined_status(pet_to_keep, pet_to_be_merged) 420 | 421 | pet_to_keep._attack = c_attack 422 | pet_to_keep._health = c_health 423 | pet_to_keep._until_end_of_battle_attack_buff = c_until_end_of_battle_attack 424 | pet_to_keep._until_end_of_battle_health_buff = c_until_end_of_battle_health 425 | pet_to_keep.status = cstatus 426 | levelup = pet_to_keep.gain_experience() 427 | 428 | # Check for levelup triggers if appropriate 429 | if levelup: 430 | # Activate the ability of the previous level 431 | pet_to_keep.level -= 1 432 | pet_to_keep.levelup_trigger(pet_to_keep) 433 | pet_to_keep.level += 1 434 | 435 | return levelup 436 | 437 | @storeaction 438 | def buy_combine(self, shop_pet, team_pet): 439 | """Combine two pets on purchase""" 440 | if type(shop_pet) == int: 441 | shop_pet = self.shop[shop_pet] 442 | if type(team_pet) == int: 443 | team_pet = self.team[team_pet] 444 | 445 | if isinstance(shop_pet, ShopSlot): 446 
| shop_pet = shop_pet.obj 447 | if isinstance(team_pet, TeamSlot): 448 | team_pet = team_pet._pet 449 | 450 | if type(shop_pet).__name__ != "Pet": 451 | raise Exception(f"Attempted buy_combined with Shop item {shop_pet}") 452 | if type(team_pet).__name__ != "Pet": 453 | raise Exception(f"Attempted buy_combined with Team Pet {team_pet}") 454 | if team_pet.name != shop_pet.name: 455 | raise Exception( 456 | f"Attempted combine for pets {team_pet.name} and {shop_pet.name}" 457 | ) 458 | 459 | shop_idx = self.shop.index(shop_pet) 460 | shop_slot = self.shop.slots[shop_idx] 461 | cost = shop_slot.cost 462 | 463 | if cost > self.gold: 464 | raise Exception( 465 | f"Attempted to buy Pet of cost {cost} with only {self.gold} gold" 466 | ) 467 | 468 | ### Make all updates 469 | self.gold -= cost 470 | self.shop.buy(shop_pet) 471 | 472 | levelup = Player.combine_pet_stats(team_pet, shop_pet) 473 | if levelup: 474 | self.shop.levelup() 475 | 476 | ### Check for buy_pet triggers 477 | for slot in self.team: 478 | slot._pet.buy_friend_trigger(team_pet) 479 | 480 | return shop_pet, team_pet 481 | 482 | @storeaction 483 | def combine(self, pet1, pet2): 484 | """Combine two pets on the team together""" 485 | if type(pet1) == int: 486 | pet1 = self.team[pet1] 487 | if type(pet2) == int: 488 | pet2 = self.team[pet2] 489 | 490 | if isinstance(pet1, TeamSlot): 491 | pet1 = pet1._pet 492 | if isinstance(pet2, TeamSlot): 493 | pet2 = pet2._pet 494 | 495 | if not self.team.check_friend(pet1): 496 | raise Exception(f"Attempted combine for Pet not on team {pet1}") 497 | if not self.team.check_friend(pet2): 498 | raise Exception(f"Attempted combine for Pet not on team {pet2}") 499 | 500 | if pet1.name != pet2.name: 501 | raise Exception(f"Attempted combine for pets {pet1.name} and {pet2.name}") 502 | 503 | levelup = Player.combine_pet_stats(pet1, pet2) 504 | if levelup: 505 | self.shop.levelup() 506 | 507 | ### Remove pet2 from team 508 | idx = self.team.index(pet2) 509 | self.team[idx] = 
TeamSlot() 510 | 511 | return pet1, pet2 512 | 513 | @storeaction 514 | def reorder(self, idx): 515 | """Reorder team""" 516 | if len(idx) != len(self.team): 517 | raise Exception("Reorder idx must match team length") 518 | unique = np.unique(idx) 519 | 520 | if len(unique) != len(self.team): 521 | raise Exception(f"Cannot input duplicate indices to reorder: {idx}") 522 | 523 | self.team = Team([self.team[x] for x in idx], seed_state=self.team.seed_state) 524 | 525 | return idx 526 | 527 | @storeaction 528 | def end_turn(self): 529 | """End turn and move to battle phase""" 530 | ### Activate eot trigger 531 | for slot in self.team: 532 | slot._pet.eot_trigger() 533 | return None 534 | 535 | @property 536 | def state(self): 537 | state_dict = { 538 | "type": "Player", 539 | "team": self.team.state, 540 | "shop": self.shop.state, 541 | "lives": self.lives, 542 | "default_gold": self.default_gold, 543 | "gold": self.gold, 544 | "lf_winner": self.lf_winner, 545 | "pack": self.pack, 546 | "turn": self.turn, 547 | "action_history": self.action_history, 548 | "wins": self.wins, 549 | } 550 | return state_dict 551 | 552 | @classmethod 553 | def from_state(cls, state): 554 | team = Team.from_state(state["team"]) 555 | shop_type = state["shop"]["type"] 556 | shop_cls = getattr(sapai.shop, shop_type) 557 | shop = shop_cls.from_state(state["shop"]) 558 | if "action_history" in state: 559 | action_history = state["action_history"] 560 | else: 561 | action_history = [] 562 | return cls( 563 | team=team, 564 | shop=shop, 565 | lives=state["lives"], 566 | default_gold=state["default_gold"], 567 | gold=state["gold"], 568 | turn=state["turn"], 569 | lf_winner=state["lf_winner"], 570 | pack=state["pack"], 571 | action_history=action_history, 572 | wins=state["wins"], 573 | ) 574 | 575 | def __repr__(self): 576 | info_str = f"PACK: {self.pack}\n" 577 | info_str += f"TURN: {self.turn}\n" 578 | info_str += f"LIVES: {self.lives}\n" 579 | info_str += f"WINS: {self.wins}\n" 580 | info_str 
+= f"GOLD: {self.gold}\n" 581 | print_str = "--------------\n" 582 | print_str += "CURRENT INFO: \n--------------\n" + info_str + "\n" 583 | print_str += "CURRENT TEAM: \n--------------\n" + self.team.__repr__() + "\n" 584 | print_str += "CURRENT SHOP: \n--------------\n" + self.shop.__repr__() 585 | return print_str 586 | 587 | 588 | def get_combined_status(pet1, pet2): 589 | """ 590 | Statuses are combined based on the tier that they come from. 591 | 592 | """ 593 | status_tier = { 594 | 0: ["status-weak", "status-poison-attack", "none"], 595 | 1: ["status-honey-bee"], 596 | 2: ["status-bone-attack"], 597 | 3: ["status-garlic-armor"], 598 | 4: ["status-splash-attack"], 599 | 5: [ 600 | "status-coconut-shield", 601 | "status-melon-armor", 602 | "status-steak-attack", 603 | "status-extra-life", 604 | ], 605 | } 606 | 607 | status_lookup = {} 608 | for key, value in status_tier.items(): 609 | for entry in value: 610 | status_lookup[entry] = key 611 | 612 | ### If there is a tie in tier, then pet1 status is used 613 | max_idx = np.argmax([status_lookup[pet1.status], status_lookup[pet2.status]]) 614 | 615 | return [pet1.status, pet2.status][max_idx] 616 | -------------------------------------------------------------------------------- /sapai/agents.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import os, json, itertools, torch 4 | import numpy as np 5 | from sapai import Player 6 | from sapai.battle import Battle 7 | from sapai.compress import compress, decompress 8 | 9 | ### Pets with a random component 10 | ### Random component in the future should just be handled in an exact way 11 | ### whereby all possible outcomes are evaluated just once. This would 12 | ### significantly speed up training. 
random_buy_pets = {"pet-otter"}
random_sell_pets = {"pet-beaver"}
random_pill_pets = {"pet-ant"}
random_battle_pets = {"pet-mosquito"}


class CombinatorialSearch:
    """
    CombinatorialSearch is a method to enumerate the entire possible search
    space for the current shopping phase. The search starts from an initial
    player state provided in the arguments. Then all possible next actions
    are taken from that player state until the Player's gold is exhausted.
    Returned is a list of all possible final player states after the
    shopping phase.

    This algorithm can be parallelized in the sense that after round 1, there
    will be a large number of player states. Each of these player states can
    then be fed back into the CombinatorialSearch individually to search for
    their round 2 possibilities. Therefore, parallelization can occur across
    the possible player states for each given round. This parallelization
    should take place outside of this Class.

    Parallelization within this Class itself would be a bit more difficult and
    would only be advantageous to improve the search speed when considering a
    single team. This would be better for run-time evaluation of models, but
    it is unnecessary for larger searches, which are the present focus.
    Therefore, this will be left until later.

    A CombinatorialAgent can be built on top of this to resemble the way that a
    human player will make decisions. However, the CombinatorialAgent will
    consider all possible combinations of decisions in order to arrive at the
    best possible next decisions given the available gold.

    Arguments
    ---------
    verbose: bool
        If True, messages are printed during the search
    max_actions: int
        Maximum depth, equal to the number of shop actions, that can be performed
        during search space enumeration. Using max_actions of 1 would correspond
        to a greedy search algorithm in conjunction with any Agents. Values <= 0
        disable the limit.

    """

    def __init__(self, verbose=True, max_actions=-1):
        self.verbose = verbose
        self.max_actions = max_actions

        ### This stores the player lists for performing all possible actions
        self.player_list = []

        ### Maps the compressed (minimal) state of every player seen so far to
        ### that player, so that the same player state is never used twice
        self.player_state_dict = {}

        self.current_print_number = 0

    def avail_actions(self, player):
        """
        Return all possible available actions

        Each action is a tuple ``(bound_method, *args)``; the first entry is
        the null action ``()``.
        """
        action_list = [()]
        ### Include only the actions that cost gold
        action_list += self.avail_buy_pets(player)
        action_list += self.avail_buy_food(player)
        action_list += self.avail_buy_combine(player)
        action_list += self.avail_team_combine(player)
        action_list += self.avail_sell(player)
        action_list += self.avail_sell_buy(player)
        action_list += self.avail_roll(player)
        return action_list

    def avail_buy_pets(self, player):
        """
        Returns all possible pets that can be bought from the player's shop
        """
        action_list = []
        gold = player.gold
        if len(player.team) == 5:
            ### Cannot buy for full team
            return action_list
        for shop_idx in player.shop.filled:
            shop_slot = player.shop[shop_idx]
            if shop_slot.slot_type == "pet":
                if shop_slot.cost <= gold:
                    action_list.append((player.buy_pet, shop_idx))
        return action_list

    def avail_buy_food(self, player):
        """
        Returns all possible food that can be bought from the player's shop
        """
        action_list = []
        gold = player.gold
        if len(player.team) == 0:
            return action_list
        for shop_idx in player.shop.filled:
            shop_slot = player.shop[shop_idx]
            if shop_slot.slot_type == "food":
                if shop_slot.cost <= gold:
                    for team_idx, team_slot in enumerate(player.team):
                        if team_slot.empty:
                            continue
                        action_list.append((player.buy_food, shop_idx, team_idx))
        return action_list

    def avail_buy_combine(self, player):
        """
        Returns all (shop pet, team pet) pairs where an affordable shop pet can
        be bought and combined onto a team pet with the same name
        """
        action_list = []
        gold = player.gold
        team_names = {}
        for team_idx, slot in enumerate(player.team):
            if slot.empty:
                continue
            if slot.pet.name not in team_names:
                team_names[slot.pet.name] = []
            team_names[slot.pet.name].append(team_idx)
        if len(player.team) == 0:
            return action_list
        for shop_idx in player.shop.filled:
            shop_slot = player.shop[shop_idx]
            if shop_slot.slot_type == "pet":
                ### Can't combine if pet not already on team
                if shop_slot.obj.name not in team_names:
                    continue

                if shop_slot.cost <= gold:
                    for team_idx in team_names[shop_slot.obj.name]:
                        action_list.append((player.buy_combine, shop_idx, team_idx))
        return action_list

    def avail_team_combine(self, player):
        """
        Returns every pair of team slots holding pets with the same name
        """
        action_list = []
        if len(player.team) == 0:
            return action_list

        team_names = {}
        for slot_idx, slot in enumerate(player.team):
            if slot.empty:
                continue
            if slot.pet.name not in team_names:
                team_names[slot.pet.name] = []
            team_names[slot.pet.name].append(slot_idx)

        for key, value in team_names.items():
            if len(value) == 1:
                continue

            for idx0, idx1 in itertools.combinations(value, r=2):
                action_list.append((player.combine, idx0, idx1))

        return action_list

    def avail_sell(self, player):
        """
        Returns a sell action for every filled team slot, unless only a single
        pet remains on the team
        """
        action_list = []
        if len(player.team) <= 1:
            ### Not able to sell the final friend on team
            return action_list
        for team_idx, slot in enumerate(player.team):
            if slot.empty:
                continue
            action_list.append((player.sell, team_idx))
        return action_list

    def avail_sell_buy(self, player):
        """
        Sell buy should only be used if the team is full. This is done so that
        the search space is not increased unnecessarily. However, a full agent
        implementation can certainly consider this action at any point in the
        game as long as there are pets to sell and buy
        """
        action_list = []
        gold = player.gold
        if len(player.team) != 5:
            return action_list
        team_idx_list = player.team.get_fidx()
        shop_idx_list = []
        for shop_idx in player.shop.filled:
            shop_slot = player.shop[shop_idx]
            if shop_slot.slot_type == "pet":
                if shop_slot.cost <= gold:
                    shop_idx_list.append(shop_idx)

        prod = itertools.product(team_idx_list, shop_idx_list)
        for temp_team_idx, temp_shop_idx in prod:
            action_list.append((player.sell_buy, temp_team_idx, temp_shop_idx))

        return action_list

    def avail_team_order(self, player):
        """Returns all possible orderings for the team"""
        action_list = []

        team_range = np.arange(0, len(player.team))
        if len(team_range) == 0:
            return []

        for order in itertools.permutations(team_range, r=len(team_range)):
            action_list.append((player.reorder, order))

        return action_list

    def avail_roll(self, player):
        """
        Returns a roll action only when the player has exactly 1 gold
        """
        ##### ASSUMPTION: If gold is not 1, do not roll for now because with
        ##### ShopLearn, rolling has no meaning
        if player.gold != 1:
            return []
        return [(player.roll,)]

    def search(self, player, player_state_dict=None):
        """
        Enumerate all shopping-phase outcomes starting from ``player``.

        Returns
        -------
        (player_list, team_dict)
            Every reachable player (after all team reorderings and end_turn),
            and a dictionary of the unique resulting teams keyed by their
            minimal compressed form.
        """
        ### Initialize internal storage
        self.player = player
        self.player_list = []
        self.player_state = self.player.state
        self.current_print_number = 0
        self.print_message("start", self.player)

        ### build_player_list searches shop actions and returns player list
        ### In addition, it builds a player_state_dict which can be used for
        ### faster lookup of redundant player states
        if player_state_dict is None:
            self.player_state_dict = {}
        else:
            self.player_state_dict = player_state_dict
        self.player_list = self.build_player_list(self.player)
        self.print_message("player_list_done", self.player_list)

        ### Now consider all possible reorderings of team
        self.player_list, self.player_state_dict = self.search_reordering(
            self.player_list, self.player_state_dict
        )

        ### End turn for all in player list
        for temp_player in self.player_list:
            temp_player.end_turn()
        ### NOTE: After end_turn the player_state_dict has not been updated,
        ### therefore, the player_state_dict is no longer reliable and should
        ### NOT be used outside of this Class. If the player_state_dict is
        ### required, it should be rebuilt from the player_list itself

        ### Also, return only the unique team list for convenience
        self.team_dict = self.get_team_dict(self.player_list)

        self.print_message("done", (self.player_list, self.team_dict))

        return self.player_list, self.team_dict

    def build_player_list(self, player, player_list=None):
        """
        Recursive function for building player list for a given turn using all
        actions during the shopping phase

        Every player state reachable from ``player`` that has not been seen
        before (per ``self.player_state_dict``) is appended to ``player_list``.
        The direct children of ``player`` come first, followed by their
        descendants. The same list is returned.

        Returns a new empty list without touching ``player_list`` when the
        player has no gold left or ``max_actions`` has been reached.

        """
        if player.gold <= 0:
            ### If gold is 0, then this is exit condition for the
            ### recursive function
            return []
        if player_list is None:
            player_list = []

        player_state = player.state
        self.print_message("size", self.player_state_dict)
        if self.max_actions > 0:
            actions_taken = len(player.action_history)
            if actions_taken >= self.max_actions:
                return []

        new_players = []
        for temp_action in self.avail_actions(player):
            if temp_action == ():
                ### Null action
                continue

            #### Re-initialize Player
            temp_player = Player.from_state(player_state)

            #### Perform action
            action_name = str(temp_action[0].__name__).split(".")[-1]
            action = getattr(temp_player, action_name)
            action(*temp_action[1:])

            ### Move team forward so that the state is index invariant
            temp_player.team.move_forward()

            ### Don't need history in order to check for redundancy of the
            ### shop state. This means that it does not matter how a Shop
            ### gets to a state, just that the state is identical to others.
            cstate = compress(temp_player, minimal=True)
            if cstate in self.player_state_dict:
                ### If player state has been seen before, then do not append
                ### to the player list.
                continue
            self.player_state_dict[cstate] = temp_player
            new_players.append(temp_player)

        player_list += new_players

        ### Recurse only into the players discovered at this level. Iterating
        ### over player_list while extending it would re-expand every
        ### descendant once more for each ancestor level.
        for child in new_players:
            self.build_player_list(child, player_list)

        return player_list

    def search_reordering(self, player_list, player_state_dict):
        """
        Searches over all possible unique reorderings of the teams

        New unique players are appended to ``player_list`` and recorded in
        ``player_state_dict``; both are returned.

        """
        additional_player_list = []
        for player in player_list:
            player_state = player.state
            reorder_actions = self.avail_team_order(player)
            for temp_action in reorder_actions:
                if temp_action == ():
                    ### Null action
                    continue

                #### Re-initialize identical Player
                temp_player = Player.from_state(player_state)

                #### Perform action
                action_name = str(temp_action[0].__name__).split(".")[-1]
                action = getattr(temp_player, action_name)
                action(*temp_action[1:])

                ### Move team forward so that the state is index invariant
                temp_player.team.move_forward()

                ### Don't need history in order to check for redundancy of the
                ### shop state. This means that it does not matter how a Shop
                ### gets to a state, just that the state is identical to others.
                cstate = compress(temp_player, minimal=True)
                if cstate in player_state_dict:
                    ### If player state has been seen before, then do not append
                    ### to the player list.
                    continue
                player_state_dict[cstate] = temp_player

                additional_player_list.append(temp_player)

                ### TODO: avoid Player.from_state per action to save time, e.g.
                ### undo the reorder with np.argsort(order) and drop the extra
                ### action_history entries instead of rebuilding the Player.

        player_list += additional_player_list
        return player_list, player_state_dict

    def get_team_dict(self, player_list):
        """
        Returns dictionary of only the unique teams

        Keys are ``compress(team, minimal=True)``. Note that each player's team
        is moved forward in place.

        """
        team_dict = {}
        for player in player_list:
            team = player.team
            ### Move forward to make team index invariant
            team.move_forward()
            cteam = compress(team, minimal=True)
            ### Can just always do like this, don't need to check if it's
            ### already in dictionary because it can just be overwritten
            team_dict[cteam] = team
        return team_dict

    def print_message(self, message_type, info):
        """Print a progress message of the given type when verbose is on."""
        if not self.verbose:
            return

        if message_type not in ["start", "size", "player_list_done", "done"]:
            raise Exception(f"Unrecognized message type {message_type}")

        if message_type == "start":
            print("---------------------------------------------------------")
            print("STARTING SEARCH WITH INITIAL PLAYER: ")
            print(info)

            print("---------------------------------------------------------")
            print("STARTING TO BUILD PLAYER LIST")

        elif message_type == "size":
            temp_size = len(info)
            ### Only report every 100 new unique players
            if temp_size < (self.current_print_number + 100):
                return
            print(f"RUNNING MESSAGE: Current Number of Unique Players is {len(info)}")
            self.current_print_number = temp_size

        elif message_type == "player_list_done":
            print("---------------------------------------------------------")
            print("DONE BUILDING PLAYER LIST")
            print(f"NUMBER OF PLAYERS IN PLAYER LIST: {len(info)}")

            print("BEGINNING TO SEARCH FOR ALL POSSIBLE TEAM ORDERS")

        elif message_type == "done":
            print("---------------------------------------------------------")
            print("DONE WITH CombinatorialSearch")
            print(f"NUMBER OF PLAYERS IN PLAYER LIST: {len(info[0])}")
            print(f"NUMBER OF UNIQUE TEAMS: {len(info[1])}")


class DatabaseLookupRanker:
    """
    Will provide a rank to a given team based on its performance on a database
    of teams.

    Arguments
    ---------
    path: str
        Path to a JSON file mapping minimal compressed team strings to their
        win fraction. If the path does not exist an empty database is used.

    """

    def __init__(
        self,
        path="",
    ):
        self.path = path

        if os.path.exists(path):
            with open(path) as f:
                self.database = json.load(f)
        else:
            self.database = {}

        ### Only the win fraction is stored on disk, so approximate the
        ### win/total counts such that wins / total matches that fraction
        self.team_database = {}
        total = len(self.database) - 1
        for key, value in self.database.items():
            self.team_database[key] = {
                "team": decompress(key),
                "wins": int(round(total * value)),
                "total": total,
            }

    def __call__(self, team):
        ### Database keys are minimal compressed teams (see run_against_database)
        c = compress(team, minimal=True)
        if c in self.database:
            return self.database[c]
        return self.run_against_database(team)

    def run_against_database(self, team):
        """
        Add ``team`` to the database, battle it against every stored team
        (including itself), update all win fractions and return the win
        fraction of ``team``.
        """
        #### Add team to database
        team_key = compress(team, minimal=True)
        if team_key not in self.team_database:
            self.team_database[team_key] = {"team": team, "wins": 0, "total": 0}

        for key, value in self.team_database.items():
            ### Keep the last pair so it can be inspected if Battle raises
            self.t0 = team
            self.t1 = value["team"]

            f = Battle(team, value["team"])
            winner = f.battle()

            ### winner: 0 -> team won, 1 -> opponent won, 2 -> draw
            winner_key = [[team_key], [key], []][winner]
            for temp_key in winner_key:
                self.team_database[temp_key]["wins"] += 1
            for temp_key in [team_key, key]:
                self.team_database[temp_key]["total"] += 1

        for key in self.team_database:
            wins = self.team_database[key]["wins"]
            total = self.team_database[key]["total"]
            self.database[key] = wins / total

        wins = self.team_database[team_key]["wins"]
        total = self.team_database[team_key]["total"]
        return wins / total

    def test_against_database(self, team):
        """
        Battle ``team`` against every stored team without modifying the
        database; returns ``(wins, total)``.
        """
        wins = 0
        total = 0
        for value in self.team_database.values():
            f = Battle(team, value["team"])
            winner = f.battle()
            if winner == 0:
                wins += 1
            total += 1
        return wins, total

class PairwiseBattles: 506 | """ 507 | Parallel function using MPI for calculation Pairwise battles. 508 | 509 | Disadvantage of current method is that not check-pointing is done in the 510 | calculation. This means that the results will be written as all or nothing. 511 | If the calculation is interrupted before finishing, than all results will 512 | be lost. This is a common issue of simple parallelization... 513 | 514 | """ 515 | 516 | def __init__(self, output="results.pt"): 517 | try: 518 | from mpi4py import MPI 519 | 520 | parallel_check = True 521 | except: 522 | parallel_check = False 523 | 524 | if not parallel_check: 525 | raise Exception("MPI parallelization not available") 526 | 527 | self.comm = MPI.COMM_WORLD 528 | self.size = self.comm.Get_size() 529 | self.rank = self.comm.Get_rank() 530 | self.output = output 531 | 532 | def battle(self, obj): 533 | ### Prepare job-list on rank 0 534 | if self.rank == 0: 535 | team_list = [] 536 | if type(obj) == dict: 537 | team_list += list(obj.values()) 538 | else: 539 | team_list += list(obj) 540 | 541 | if type(team_list[0]).__name__ != "Team": 542 | raise Exception("Input object is not Team Dict or Team List") 543 | 544 | print("------------------------------------", flush=True) 545 | print("RUNNING PAIRWISE BATTLES", flush=True) 546 | print(f"{'NUM':16s}: NUMBER", flush=True) 547 | print(f"{'NUM RANKS':16s}: {self.size}", flush=True) 548 | print(f"{'INPUT TEAMS':16s}: {len(obj)}", flush=True) 549 | 550 | ### Easier for indexing 551 | team_array = np.zeros((len(team_list),), dtype=object) 552 | team_array[:] = team_list[:] 553 | pair_idx = self._get_pair_idx(team_list) 554 | print(f"{'NUMBER BATTLES':16s}: {len(pair_idx)}", flush=True) 555 | ### Should I send just index and read in files on all ranks... 556 | ### or should Teams be sent to ranks... 557 | ### Well, I don't think this function will every have >2 GB sized 558 | ### team dataset anyways... 
559 | for temp_rank in np.arange(1, self.size): 560 | temp_idx = pair_idx[temp_rank :: self.size] 561 | temp_teams = np.take(team_array, temp_idx) 562 | self.comm.send((temp_idx, temp_teams), temp_rank) 563 | 564 | my_idx = pair_idx[0 :: self.size] 565 | my_teams = np.take(team_array, my_idx) 566 | print(f"{'BATTLES PER RANK':16s}: {len(my_teams)}", flush=True) 567 | 568 | else: 569 | ### Wait for info from rank 0 570 | my_idx, my_teams = self.comm.recv(source=0) 571 | 572 | if self.rank != 0: 573 | winner_list = [] 574 | iter_idx = 0 575 | for t0, t1 in my_teams: 576 | b = Battle(t0, t1) 577 | temp_winner = b.battle() 578 | winner_list.append(temp_winner) 579 | iter_idx += 1 580 | else: 581 | #### This is split to remove branching code in the for loop above 582 | winner_list = [] 583 | iter_idx = 0 584 | for t0, t1 in my_teams: 585 | b = Battle(t0, t1) 586 | temp_winner = b.battle() 587 | winner_list.append(temp_winner) 588 | iter_idx += 1 589 | if iter_idx % 1000 == 0: 590 | print( 591 | f"{'FINISHED':16s}: {iter_idx * self.size} of {len(pair_idx)}", 592 | flush=True, 593 | ) 594 | 595 | winner_list = np.array(winner_list).astype(int) 596 | 597 | ### Send results back to rank 0 598 | self.comm.barrier() 599 | 600 | if self.rank == 0: 601 | print("------------------------------------", flush=True) 602 | print("DONE CALCULATING BATTLES", flush=True) 603 | ### Using +1 so that the last entry in the array can be used as 604 | ### as throw-away when draws occur 605 | wins = np.zeros((len(team_array) + 1,)).astype(int) 606 | total = np.zeros((len(team_array) + 1,)).astype(int) 607 | ### Add info from rank 0 608 | add_totals_idx, add_totals = np.unique(my_idx[:, 0], return_counts=True) 609 | total[add_totals_idx] += add_totals 610 | add_totals_idx, add_totals = np.unique(my_idx[:, 1], return_counts=True) 611 | total[add_totals_idx] += add_totals 612 | ### Use for fast indexing for counting up wins 613 | temp_draw_idx = np.zeros((len(my_idx),)) - 1 614 | winner_idx_mask 
= np.hstack([my_idx, temp_draw_idx[:, None]]) 615 | winner_idx_mask = winner_idx_mask.astype(int) 616 | winner_idx = winner_idx_mask[np.arange(0, len(winner_list)), winner_list] 617 | winner_idx, win_count = np.unique(winner_idx, return_counts=True) 618 | wins[winner_idx] += win_count 619 | 620 | for temp_rank in np.arange(1, self.size): 621 | temp_idx, temp_winner_list = self.comm.recv(source=temp_rank) 622 | add_totals_idx, add_totals = np.unique( 623 | temp_idx[:, 0], return_counts=True 624 | ) 625 | total[add_totals_idx] += add_totals 626 | add_totals_idx, add_totals = np.unique( 627 | temp_idx[:, 1], return_counts=True 628 | ) 629 | total[add_totals_idx] += add_totals 630 | ### Use for fast indexing for counting up wins 631 | temp_draw_idx = np.zeros((len(temp_idx),)) - 1 632 | winner_idx_mask = np.hstack([temp_idx, temp_draw_idx[:, None]]) 633 | winner_idx_mask = winner_idx_mask.astype(int) 634 | winner_idx = winner_idx_mask[ 635 | np.arange(0, len(temp_winner_list)), temp_winner_list 636 | ] 637 | winner_idx, win_count = np.unique(winner_idx, return_counts=True) 638 | wins[winner_idx] += win_count 639 | 640 | ### Throw away last entry for ties 641 | wins = wins[0:-1] 642 | total = total[0:-1] 643 | frac = wins / total 644 | 645 | results = {} 646 | for iter_idx, temp_team in enumerate(team_list): 647 | temp_frac = frac[iter_idx] 648 | results[compress(temp_team, minimal=True)] = temp_frac 649 | 650 | print(f"WRITING OUTPUTS AT: {self.output}", flush=True) 651 | torch.save(results, self.output) 652 | 653 | print("------------------------------------", flush=True) 654 | print("COMPLETED", flush=True) 655 | else: 656 | self.comm.send((my_idx, winner_list), 0) 657 | 658 | ### Barrier before exiting 659 | self.comm.barrier() 660 | return 661 | 662 | def _get_pair_idx(self, team_list): 663 | """ 664 | Get the dictionary of pair_dict that have to be made for pair mode 665 | esxecution. 
666 | 667 | """ 668 | idx = np.triu_indices(n=len(team_list), k=1, m=len(team_list)) 669 | return np.array([x for x in zip(idx[0], idx[1])]) 670 | 671 | 672 | # %% 673 | -------------------------------------------------------------------------------- /sapai/battle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sapai.data import data 4 | from sapai.effects import ( 5 | get_effect_function, 6 | get_teams, 7 | RespawnPet, 8 | SummonPet, 9 | SummonRandomPet, 10 | ) 11 | from sapai import status 12 | 13 | 14 | class Battle: 15 | """ 16 | Performs a battle. 17 | 18 | Most important thing here to implement is the action queue including the 19 | logic for when actions should be removed from the action queue upon death. 20 | 21 | Note that effects are performed in the order of highest attack to lowest 22 | attack. If there is a tie, then health values are compared. If there is a 23 | tie then a random animal is chosen first. This is tracked by the 24 | pet_priority which is updated before every turn of the battle. 25 | 26 | Any effect which is in the queue for a given turn is executed, even if the 27 | animal dies due to preceeding effect, as the game entails. 28 | 29 | A Battle goes as follows: 30 | 1. execute start-of-turn abilities according to pet priority 31 | 2. perform hurt and faint abilities according to pet priority 32 | 2.1 Execute 2 until there are no new fainted animals 33 | 3. before-attack abilities according to pet priority 34 | 4. perform fainted pet abilities via pet priority 35 | 4.1 Execute 4 until there are no new fainted animals 36 | 5. attack phase 37 | 5.0 perform before_attack abilities 38 | 5.1. 
perform hurt and fainted abilities according to pet priority 39 | 5.1.1 Execute 5.1 until there are no new fainted animals 40 | 5.2 perform attack damage 41 | 5.3 perform after attack abilities 42 | 5.4 perform hurt and fainted abilities according to pet priority 43 | 5.4.1 Execute 5.4 until there are no new fainted animals 44 | 5.5. check if knock-out abilities should be performed 45 | 5.5.1 if knock-out ability activated jump to 5.5 46 | 5.6. if battle has not ended, jump to 5.0 47 | 48 | """ 49 | 50 | def __init__(self, t0, t1): 51 | """ 52 | Performs the battle between the input teams t1 and t2. 53 | 54 | """ 55 | ### Make copy each team to cary out the battle so that the original 56 | ### pets are not modified in any way after the battle 57 | self.t0 = t0.copy() 58 | self.t0._battle = True 59 | self.t1 = t1.copy() 60 | self.t1._battle = True 61 | 62 | ### Internal storage 63 | self.pet_priority = [] 64 | self.battle_history = {} 65 | 66 | ### Build initial effect queue order 67 | self.pet_priority = self.update_pet_priority(self.t0, self.t1) 68 | 69 | def battle(self): 70 | ### Perform all effects that occur at the start of the battle 71 | self.start() 72 | 73 | battle_iter = 0 74 | while True: 75 | ### First update effect order 76 | self.pet_priority = self.update_pet_priority(self.t0, self.t1) 77 | ### Then attack 78 | result = self.attack(battle_iter) 79 | battle_iter += 1 80 | if not result: 81 | break 82 | 83 | ### Check winner and return 0 for t0 win, 1 for t1 win, 2 for draw 84 | return self.check_battle_result() 85 | 86 | def start(self): 87 | """ 88 | Perform all start of battle effects 89 | 90 | """ 91 | ### First move the teams forward 92 | t0 = self.t0 93 | t1 = self.t1 94 | teams = [t0, t1] 95 | 96 | ### Phase of start 97 | phase_dict = { 98 | "init": [[str(x) for x in t0], [str(x) for x in t1]], 99 | "start": { 100 | "phase_move_start": [], 101 | "phase_start": [], 102 | "phase_hurt_and_faint": [], 103 | "phase_move_end": [], 104 | }, 105 | } 106 
| 107 | for temp_phase in phase_dict["start"]: 108 | battle_phase( 109 | self, temp_phase, teams, self.pet_priority, phase_dict["start"] 110 | ) 111 | 112 | self.battle_history.update(phase_dict) 113 | 114 | ### If animals have moved or fainted then effect order must be updated 115 | if temp_phase.startswith("phase_move"): 116 | self.pet_priority = self.update_pet_priority(t0, t1) 117 | 118 | def attack(self, battle_iter): 119 | """ 120 | Perform and attack and then check for new pet triggers 121 | 122 | Returns whether or not another attack should occur. This depends on 123 | if all animals of one team have a health of 0 already. 124 | 125 | Order of operations for an attack are: 126 | - Pets in the front of each team attack 127 | - Apply effects related to this attack 128 | - Apply effects related to pet deaths 129 | - Summon phase 130 | - Check if battle is over 131 | 132 | """ 133 | t0 = self.t0 134 | t1 = self.t1 135 | 136 | attack_str = f"attack {battle_iter}" 137 | phase_dict = { 138 | attack_str: { 139 | "phase_move_start": [], 140 | "phase_attack_before": [], 141 | "phase_hurt_and_faint_ab": [], 142 | "phase_attack": [], 143 | "phase_attack_after": [], 144 | "phase_hurt_and_faint_aa": [], 145 | "phase_knockout": [], 146 | "phase_hurt_and_faint_k": [], 147 | "phase_move_end": [], 148 | } 149 | } 150 | 151 | ### Check exit condition, if one team has no animals, return False 152 | found0 = False 153 | for temp_slot in t0: 154 | if not temp_slot.empty: 155 | found0 = True 156 | break 157 | found1 = False 158 | for temp_slot in t1: 159 | if not temp_slot.empty: 160 | found1 = True 161 | break 162 | if not found0: 163 | return False 164 | if not found1: 165 | return False 166 | 167 | teams = [t0, t1] 168 | for temp_phase in phase_dict[attack_str]: 169 | if temp_phase == "phase_hurt_and_faint_k": 170 | ### This is checked in phase_knockout for recursive Rhino behavior 171 | continue 172 | 173 | battle_phase( 174 | self, temp_phase, teams, self.pet_priority, 
phase_dict[attack_str] 175 | ) 176 | 177 | self.battle_history.update(phase_dict) 178 | 179 | ### Check if battle is over 180 | status = self.check_battle_result() 181 | if status < 0: 182 | return True 183 | else: 184 | ### End battle 185 | return False 186 | 187 | def check_battle_result(self): 188 | t0 = self.t0 189 | t1 = self.t1 190 | found0 = False 191 | for temp_slot in t0: 192 | if not temp_slot.empty: 193 | if temp_slot.pet.health > 0: 194 | found0 = True 195 | break 196 | found1 = False 197 | for temp_slot in t1: 198 | if not temp_slot.empty: 199 | if temp_slot.pet.health > 0: 200 | found1 = True 201 | break 202 | if found0 and found1: 203 | ### Fight not over 204 | return -1 205 | if found0: 206 | ### t0 won 207 | return 0 208 | if found1: 209 | ### t1 won 210 | return 1 211 | ### Must have been draw 212 | return 2 213 | 214 | @staticmethod 215 | def update_pet_priority(t0, t1): 216 | """ 217 | 218 | Prepares the order that the animals effects should be considered in 219 | 220 | Note that effects are performed in the order of highest attack to lowest 221 | attack. If there is a tie, then health values are compared. If there is 222 | a tie then a random animal is chosen first. 
223 | 224 | """ 225 | ### Build all data types to determine effect order 226 | attack = [x.attack for x in t0] + [x.attack for x in t1] 227 | health = [x.health for x in t0] + [x.health for x in t1] 228 | teams = [0 for _ in t0] + [1 for _ in t1] 229 | idx = [x for x in range(5)] + [x for x in range(5)] 230 | 231 | for iter_idx, value in enumerate(attack): 232 | if value == "none": 233 | attack[iter_idx] = 0 234 | health[iter_idx] = 0 235 | 236 | ### Basic sorting by max attack 237 | # sort_idx = np.argsort(attack)[::-1] 238 | # attack = np.array([attack[x] for x in sort_idx]) 239 | # health = np.array([health[x] for x in sort_idx]) 240 | # teams = np.array([teams[x] for x in sort_idx]) 241 | # idx = np.array([idx[x] for x in sort_idx]) 242 | sort_idx = np.arange(0, len(attack)) 243 | attack = np.array(attack) 244 | health = np.array(attack) 245 | teams = np.array(teams) 246 | idx = np.array(idx) 247 | 248 | ### Find attack collisions 249 | uniquea = np.unique(attack)[::-1] 250 | start_idx = 0 251 | for uattack in uniquea: 252 | ### Get collision idx 253 | temp_idx = np.where(attack == uattack)[0] 254 | temp_attack = attack[temp_idx] 255 | 256 | ### Initialize final idx for sorting 257 | temp_sort_idx = np.arange(0, len(temp_idx)) 258 | 259 | if len(temp_idx) < 2: 260 | end_idx = start_idx + len(temp_idx) 261 | sort_idx[start_idx:end_idx] = temp_idx 262 | start_idx = end_idx 263 | continue 264 | 265 | ### Correct attack collisions by adding in health 266 | temp_health = health[temp_idx] 267 | temp_stats = temp_attack + temp_health 268 | temp_start_idx = 0 269 | for ustats in np.unique(temp_stats)[::-1]: 270 | temp_sidx = np.where(temp_stats == ustats)[0] 271 | temp_sidx = np.random.choice( 272 | temp_sidx, size=(len(temp_sidx),), replace=False 273 | ) 274 | temp_end_idx = temp_start_idx + len(temp_sidx) 275 | temp_sort_idx[temp_start_idx:temp_end_idx] = temp_sidx 276 | temp_start_idx = temp_end_idx 277 | 278 | ### Double check algorithm 279 | sorted_attack = 
[temp_attack[x] for x in temp_sort_idx] 280 | sorted_health = [temp_health[x] for x in temp_sort_idx] 281 | for iter_idx, tempa in enumerate(sorted_attack[1:-1]): 282 | iter_idx += 1 283 | if tempa < sorted_attack[iter_idx]: 284 | raise Exception("That's impossible. Sorting issue.") 285 | for iter_idx, temph in enumerate(sorted_health[1:-1]): 286 | iter_idx += 1 287 | if temph < sorted_health[iter_idx]: 288 | raise Exception("That's impossible. Sorting issue.") 289 | 290 | ### Dereference temp_sort_idx and store in sort_idx 291 | end_idx = start_idx + len(temp_idx) 292 | sort_idx[start_idx:end_idx] = temp_idx 293 | start_idx = end_idx 294 | 295 | ### Finish sorting by max attack 296 | attack = np.array([attack[x] for x in sort_idx]) 297 | teams = np.array([teams[x] for x in sort_idx]) 298 | idx = np.array([idx[x] for x in sort_idx]) 299 | 300 | ### Double check sorting algorithm 301 | for iter_idx, tempa in enumerate(attack[1:-1]): 302 | iter_idx += 2 303 | if tempa < attack[iter_idx]: 304 | raise Exception("That's impossible. Sorting issue.") 305 | 306 | ### Build final queue 307 | pet_priority = [] 308 | for t, i in zip(teams, idx): 309 | if [t0, t1][t][i].empty: 310 | continue 311 | pet_priority.append((t, i)) 312 | 313 | return pet_priority 314 | 315 | 316 | class RBattle(Battle): 317 | """ 318 | This class will calculate all possible outcomes of a SAP battle considering 319 | all paths of random behavior. The advantage is that probabilities of winning 320 | are evaluated exactly rather than requiring bootstrapped probabilities. 321 | 322 | Disadvantage is that it is possible that huge number of paths must be 323 | evaluated to determine exact probabilities. Protection against (could) be 324 | implemented in two ways: 325 | 1. Determining that paths lead to nomial identical results and can 326 | merge back together improving calculation efficiency 327 | 2. 
Define a maximum path size and if the number paths detected is larger 328 | then probabilities are bootstrapped. 329 | 330 | """ 331 | 332 | def __init__(self, t0, t1, max_paths=1000): 333 | """ 334 | Performs the battle between the input teams t1 and t2. 335 | 336 | """ 337 | super(RBattle, self).__init__(t0, t1) 338 | ### Make copy each team to cary out the battle so that the original 339 | ### pets are not modified in any way after the battle 340 | self.t0 = t0.copy() 341 | self.t0._battle = True 342 | self.t1 = t1.copy() 343 | self.t1._battle = True 344 | 345 | ### Internal storage 346 | self.battle_list = [] 347 | 348 | ### Build initial effect queue order 349 | self.pet_priority = self.update_pet_priority(self.t0, self.t1) 350 | 351 | raise NotImplementedError 352 | 353 | 354 | def battle_phase(battle_obj, phase, teams, pet_priority, phase_dict): 355 | """ 356 | Definition for performing all effects and actions throughout the battle. 357 | Implemented as function instead of class method to save an extra 358 | indentation. 
359 | s 360 | """ 361 | 362 | ##### Trigger logic for starting battle 363 | if phase.startswith("phase_move"): 364 | start_order = [[str(x) for x in teams[0]], [str(x) for x in teams[1]]] 365 | teams[0].move_forward() 366 | teams[1].move_forward() 367 | end_order = [[str(x) for x in teams[0]], [str(x) for x in teams[1]]] 368 | phase_dict[phase] = [start_order, end_order] 369 | 370 | elif phase == "phase_start": 371 | battle_phase_start(battle_obj, phase, teams, pet_priority, phase_dict) 372 | 373 | ##### Trigger logic for an attack 374 | elif phase == "phase_attack_before": 375 | battle_phase_attack_before(battle_obj, phase, teams, pet_priority, phase_dict) 376 | 377 | elif phase == "phase_attack": 378 | ### Check if fainted and performed fainted triggers 379 | battle_phase_attack(battle_obj, phase, teams, pet_priority, phase_dict) 380 | 381 | elif phase == "phase_attack_after": 382 | battle_phase_attack_after(battle_obj, phase, teams, pet_priority, phase_dict) 383 | 384 | elif "phase_hurt_and_faint" in phase: 385 | battle_phase_hurt_and_faint(battle_obj, phase, teams, pet_priority, phase_dict) 386 | 387 | elif phase == "phase_knockout": 388 | battle_phase_knockout(battle_obj, phase, teams, pet_priority, phase_dict) 389 | 390 | else: 391 | raise Exception(f"Phase {phase} not found") 392 | 393 | 394 | def append_phase_list(phase_list, p, team_idx, pet_idx, activated, targets, possible): 395 | if activated: 396 | tiger = False 397 | if len(targets) > 0: 398 | if type(targets[0]) == list: 399 | tiger = True 400 | func = get_effect_function(p) 401 | 402 | if not tiger: 403 | phase_list.append( 404 | ( 405 | func.__name__, 406 | (team_idx, pet_idx), 407 | (p.__repr__()), 408 | [str(x) for x in targets], 409 | ) 410 | ) 411 | else: 412 | for temp_target in targets: 413 | phase_list.append( 414 | ( 415 | func.__name__, 416 | (team_idx, pet_idx), 417 | (p.__repr__()), 418 | [str(x) for x in temp_target], 419 | ) 420 | ) 421 | 422 | 423 | def check_summon_triggers( 424 | 
phase_list, p, team_idx, pet_idx, fteam, activated, targets, possible 425 | ): 426 | if not activated: 427 | return 0 428 | 429 | func = get_effect_function(p) 430 | if func not in [RespawnPet, SummonPet, SummonRandomPet]: 431 | return 0 432 | 433 | if "team" in p.ability["effect"]: 434 | team = p.ability["effect"]["team"] 435 | if team == "Enemy": 436 | return 0 437 | 438 | ### Otherwise, summon triggers need to be checked for each Pet in targets 439 | if len(targets) > 0: 440 | if type(targets[0]) == list: 441 | temp_all_targets = [] 442 | for entry in targets: 443 | temp_all_targets += entry 444 | targets = temp_all_targets 445 | 446 | for temp_te in targets: 447 | for temp_slot in fteam: 448 | temp_pet = temp_slot.pet 449 | tempa, tempt, tempp = temp_pet.friend_summoned_trigger(temp_te) 450 | append_phase_list( 451 | phase_list, temp_pet, team_idx, pet_idx, tempa, tempt, tempp 452 | ) 453 | 454 | return len(targets) 455 | 456 | 457 | def check_self_summoned_triggers(teams, pet_priority, phase_dict): 458 | """ 459 | Currently only butterfly 460 | 461 | """ 462 | 463 | phase_list = phase_dict["phase_start"] 464 | pp = pet_priority 465 | for team_idx, pet_idx in pp: 466 | p = teams[team_idx][pet_idx].pet 467 | if p.health <= 0: 468 | continue 469 | if p.ability["trigger"] != "Summoned": 470 | continue 471 | if p.ability["triggeredBy"]["kind"] != "Self": 472 | continue 473 | 474 | func = get_effect_function(p) 475 | target = func(p, [0, pet_idx], teams, te=p) 476 | append_phase_list(phase_list, p, team_idx, pet_idx, True, target, [target]) 477 | 478 | 479 | def check_status_triggers(phase_list, p, team_idx, pet_idx, teams): 480 | if p.status not in ["status-honey-bee", "status-extra-life"]: 481 | return 482 | 483 | ability = data["statuses"][p.status]["ability"] 484 | p.set_ability(ability) 485 | te_idx = [team_idx, pet_idx] 486 | activated, targets, possible = p.faint_trigger(p, te_idx) 487 | append_phase_list(phase_list, p, team_idx, pet_idx, activated, targets, 
possible) 488 | check_summon_triggers( 489 | phase_list, p, team_idx, pet_idx, teams[team_idx], activated, targets, possible 490 | ) 491 | 492 | 493 | def battle_phase_start(battle_obj, phase, teams, pet_priority, phase_dict): 494 | phase_list = phase_dict["phase_start"] 495 | pp = pet_priority 496 | for team_idx, pet_idx in pp: 497 | p = teams[team_idx][pet_idx].pet 498 | fteam, oteam = get_teams([team_idx, pet_idx], teams) 499 | activated, targets, possible = p.sob_trigger(oteam) 500 | append_phase_list( 501 | phase_list, p, team_idx, pet_idx, activated, targets, possible 502 | ) 503 | 504 | check_self_summoned_triggers(teams, pet_priority, phase_dict) 505 | 506 | return phase_list 507 | 508 | 509 | def battle_phase_hurt_and_faint(battle_obj, phase, teams, pet_priority, phase_dict): 510 | phase_list = phase_dict[phase] 511 | pp = pet_priority 512 | status_list = [] 513 | while True: 514 | ### Get a list of fainted pets 515 | fainted_list = [] 516 | for team_idx, pet_idx in pp: 517 | p = teams[team_idx][pet_idx].pet 518 | if p.name == "pet-none": 519 | continue 520 | if p.health <= 0: 521 | fainted_list.append([team_idx, pet_idx]) 522 | if p.status != "none": 523 | status_list.append([p, team_idx, pet_idx]) 524 | 525 | ### Check every fainted pet 526 | faint_targets_list = [] 527 | for team_idx, pet_idx in fainted_list: 528 | fteam, oteam = get_teams([team_idx, pet_idx], teams) 529 | fainted_pet = fteam[pet_idx].pet 530 | ### Check for all pets that trigger off this fainted pet (including self) 531 | for te_team_idx, te_pet_idx in pp: 532 | other_pet = teams[te_team_idx][te_pet_idx].pet 533 | te_idx = [te_team_idx, te_pet_idx] 534 | activated, targets, possible = other_pet.faint_trigger( 535 | fainted_pet, te_idx, oteam 536 | ) 537 | if activated: 538 | faint_targets_list.append( 539 | [ 540 | fainted_pet, 541 | te_team_idx, 542 | te_pet_idx, 543 | activated, 544 | targets, 545 | possible, 546 | ] 547 | ) 548 | append_phase_list( 549 | phase_list, 550 | other_pet, 
551 | te_team_idx, 552 | te_pet_idx, 553 | activated, 554 | targets, 555 | possible, 556 | ) 557 | 558 | ### If no trigger was activated, then the pet was never removed. 559 | ### Check to see if it should be removed now. 560 | if teams[team_idx].check_friend(fainted_pet): 561 | teams[team_idx].remove(fainted_pet) 562 | ### Add this info to phase list 563 | phase_list.append( 564 | ("Fainted", (team_idx, pet_idx), (fainted_pet.__repr__()), [""]) 565 | ) 566 | 567 | ### If pet was summoned, then need to check for summon triggers 568 | for ( 569 | fainted_pet, 570 | team_idx, 571 | pet_idx, 572 | activated, 573 | targets, 574 | possible, 575 | ) in faint_targets_list: 576 | fteam, _ = get_teams([team_idx, pet_idx], teams) 577 | check_summon_triggers( 578 | phase_list, 579 | fainted_pet, 580 | team_idx, 581 | pet_idx, 582 | fteam, 583 | activated, 584 | targets, 585 | possible, 586 | ) 587 | 588 | ### If pet was hurt, then need to check for hurt triggers 589 | hurt_list = [] 590 | for team_idx, pet_idx in pp: 591 | fteam, oteam = get_teams([team_idx, pet_idx], teams) 592 | p = fteam[pet_idx].pet 593 | while p._hurt > 0: 594 | hurt_list.append([team_idx, pet_idx]) 595 | activated, targets, possible = p.hurt_trigger(oteam) 596 | append_phase_list( 597 | phase_list, p, team_idx, pet_idx, activated, targets, possible 598 | ) 599 | 600 | battle_obj.pet_priority = battle_obj.update_pet_priority( 601 | battle_obj.t0, battle_obj.t1 602 | ) 603 | pp = battle_obj.pet_priority 604 | 605 | ### If nothing happend, stop the loop 606 | if len(fainted_list) == 0 and len(hurt_list) == 0: 607 | break 608 | 609 | ### Check for status triggers on pet 610 | for p, team_idx, pet_idx in status_list: 611 | check_status_triggers(phase_list, p, team_idx, pet_idx, teams) 612 | 613 | return phase_list 614 | 615 | 616 | def battle_phase_attack_before(battle_obj, phase, teams, pet_priority, phase_dict): 617 | phase_list = phase_dict["phase_attack_before"] 618 | aidx, nidx = get_attack_idx(phase, 
teams, pet_priority, phase_dict) 619 | pp = pet_priority 620 | if len(aidx) != 2: 621 | ### Must be two animals available for attacking to continue with battle 622 | return phase_list 623 | for team_idx, pet_idx in pp: 624 | if aidx[team_idx][1] != pet_idx: 625 | ### Effects are only activated for the attacking pet 626 | continue 627 | p = teams[team_idx][pet_idx].pet 628 | fteam, oteam = get_teams([team_idx, pet_idx], teams) 629 | activated, targets, possible = p.before_attack_trigger(oteam) 630 | append_phase_list( 631 | phase_list, p, team_idx, pet_idx, activated, targets, possible 632 | ) 633 | 634 | return phase_dict 635 | 636 | 637 | def get_attack_idx(phase, teams, pet_priority, phase_dict): 638 | """ 639 | Helper function to get the current animals participating in the attack. 640 | These are defined as the first animals in each team that have a health above 641 | zero. 642 | """ 643 | ### Only check for the first target 644 | ### Ff there is no target it means the target fainted in the 'before_attack' phase 645 | if not teams[0][0].empty and teams[0][0].pet.health > 0: 646 | t0_idx = 0 647 | else: 648 | t0_idx = -1 649 | 650 | if not teams[1][0].empty and teams[1][0].pet.health > 0: 651 | t1_idx = 0 652 | else: 653 | t1_idx = -1 654 | 655 | ret_idx = [] 656 | if t0_idx > -1: 657 | ret_idx.append((0, t0_idx)) 658 | if t1_idx > -1: 659 | ret_idx.append((1, t1_idx)) 660 | 661 | ### Getting next idx at the same time 662 | t0_next_idx = -1 663 | for iter_idx, temp_slot in enumerate(teams[0]): 664 | if not temp_slot.empty: 665 | if temp_slot.pet.health > 0: 666 | if t0_idx == iter_idx: 667 | continue 668 | t0_next_idx = iter_idx 669 | break 670 | t1_next_idx = -1 671 | for iter_idx, temp_slot in enumerate(teams[1]): 672 | if not temp_slot.empty: 673 | if temp_slot.pet.health > 0: 674 | if t1_idx == iter_idx: 675 | continue 676 | t1_next_idx = iter_idx 677 | break 678 | ret_next_idx = [] 679 | if t0_next_idx > -1: 680 | ret_next_idx.append((0, t0_next_idx)) 681 | 
else: 682 | ret_next_idx.append(()) 683 | if t1_next_idx > -1: 684 | ret_next_idx.append((1, t1_next_idx)) 685 | else: 686 | ret_next_idx.append(()) 687 | 688 | return ret_idx, ret_next_idx 689 | 690 | 691 | def battle_phase_attack_after(battle_obj, phase, teams, pet_priority, phase_dict): 692 | phase_list = phase_dict[phase] 693 | pp = pet_priority 694 | 695 | #### Can get the two animals that just previously attacked from the 696 | #### phase_dict 697 | attack_history = phase_dict["phase_attack"] 698 | if len(attack_history) == 0: 699 | return phase_dict 700 | 701 | t0_pidx = attack_history[0][1][0] 702 | t1_pidx = attack_history[0][1][1] 703 | 704 | for team_idx, pet_idx in pp: 705 | ### Check if current pet is directly behind the pet that just attacked 706 | test_idx = [t0_pidx, t1_pidx][team_idx] + 1 707 | if pet_idx != test_idx: 708 | continue 709 | 710 | ### If it is, then the after_attack ability can be activated 711 | p = teams[team_idx][pet_idx].pet 712 | fteam, oteam = get_teams([team_idx, pet_idx], teams) 713 | activated, targets, possible = p.after_attack_trigger(oteam) 714 | append_phase_list( 715 | phase_list, p, team_idx, pet_idx, activated, targets, possible 716 | ) 717 | 718 | return phase_dict 719 | 720 | 721 | def battle_phase_knockout(battle_obj, phase, teams, pet_priority, phase_dict): 722 | phase_list = phase_dict[phase] 723 | 724 | #### Get knockout list from the end of the phase_attack info and remove 725 | #### the knockout list from phase attack 726 | attack_history = phase_dict["phase_attack"] 727 | if len(attack_history) == 0: 728 | return phase_dict 729 | knockout_list = attack_history[-1] 730 | phase_dict["phase_attack"] = phase_dict["phase_attack"][0:-1] 731 | 732 | for apet, team_idx in knockout_list: 733 | if apet.health > 0: 734 | ### Need to loop to handle Rhino 735 | fteam, oteam = get_teams([team_idx, 0], teams) 736 | current_length = 0 737 | while True: 738 | pet_idx = fteam.index(apet) 739 | activated, targets, possible = 
apet.knockout_trigger(oteam) 740 | append_phase_list( 741 | phase_list, apet, team_idx, pet_idx, activated, targets, possible 742 | ) 743 | 744 | if not activated: 745 | ### Easy breaking condition 746 | break 747 | 748 | battle_phase( 749 | battle_obj, 750 | "phase_hurt_and_faint_k", 751 | teams, 752 | pet_priority, 753 | phase_dict, 754 | ) 755 | 756 | if len(phase_dict["phase_hurt_and_faint_k"]) == current_length: 757 | ### No more recursion needed because nothing else fainted 758 | break 759 | else: 760 | ### Otherwise, something has been knockedout by Rhino 761 | ### ability and while loop should iterate again 762 | current_length = len(phase_dict["phase_hurt_and_faint_k"]) 763 | 764 | return phase_dict 765 | 766 | 767 | def get_attack(p0, p1): 768 | """Ugly but works""" 769 | attack_list = [p1.get_damage(p0.attack), p0.get_damage(p1.attack)] 770 | if p0.status in status.apply_once: 771 | p0.status = "none" 772 | if p1.status in status.apply_once: 773 | p1.status = "none" 774 | return attack_list 775 | 776 | 777 | def battle_phase_attack(battle_obj, phase, teams, pet_priority, phase_dict): 778 | phase_list = phase_dict["phase_attack"] 779 | aidx, nidx = get_attack_idx(phase, teams, pet_priority, phase_dict) 780 | if len(aidx) != 2: 781 | ### Must be two animals available for attacking to continue with battle 782 | return phase_list 783 | 784 | p0 = teams[0][aidx[0][1]].pet 785 | p1 = teams[1][aidx[1][1]].pet 786 | 787 | #### Implement food 788 | p0a, p1a = get_attack(p0, p1) 789 | 790 | teams[0][aidx[0][1]].pet.hurt(p1a) 791 | teams[1][aidx[1][1]].pet.hurt(p0a) 792 | phase_list.append(["Attack", (aidx[0]), str(p0), [str(p1)]]) 793 | 794 | ### Keep track of knockouts for rhino and hippo by: 795 | ### (attacking_pet, team_idx) 796 | knockout_list = [] 797 | if teams[0][aidx[0][1]].pet.health <= 0: 798 | knockout_list.append((p1, 1)) 799 | if teams[1][aidx[1][1]].pet.health <= 0: 800 | knockout_list.append((p0, 0)) 801 | 802 | ### Implement chili 803 | if 
p0.status == "status-splash-attack": 804 | original_attack = p0._attack 805 | original_tmp_attack = p0._until_end_of_battle_attack_buff 806 | original_status = p0.status 807 | p0._attack = 5 808 | p0._until_end_of_battle_attack_buff = 0 809 | if len(nidx[1]) != 0: 810 | pn1 = teams[1][nidx[1][1]].pet 811 | p0a, p1a = get_attack(p0, pn1) 812 | pn1.hurt(p0a) 813 | phase_list.append(["splash", (aidx[0]), (str(p0)), [str(pn1)]]) 814 | 815 | if pn1.health <= 0: 816 | knockout_list.append((p0, 0)) 817 | 818 | p0.status = original_status 819 | p0._attack = original_attack 820 | p0._until_end_of_battle_attack_buff = original_tmp_attack 821 | 822 | if p1.status == "status-splash-attack": 823 | original_attack = p1._attack 824 | original_tmp_attack = p1._until_end_of_battle_attack_buff 825 | original_status = p1.status 826 | p1._attack = 5 827 | p1._until_end_of_battle_attack_buff = 0 828 | if len(nidx[0]) != 0: 829 | pn0 = teams[0][nidx[0][1]].pet 830 | p0a, p1a = get_attack(pn0, p1) 831 | pn0.hurt(p1a) 832 | phase_list.append(["splash", (aidx[1]), (str(p1)), [str(pn0)]]) 833 | 834 | if pn0.health <= 0: 835 | knockout_list.append((p1, 1)) 836 | 837 | p1.status = original_status 838 | p1._attack = original_attack 839 | p1._until_end_of_battle_attack_buff = original_tmp_attack 840 | 841 | ### Add knockout list to the end of phase_list. This is later removed 842 | ### in the knockout phase 843 | phase_list.append(knockout_list) 844 | 845 | return phase_dict 846 | --------------------------------------------------------------------------------